From 45eda1e4d49283a549e6ab81989ad56d2eeba831 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Wed, 4 Feb 2026 14:09:35 +0100 Subject: [PATCH 01/18] cleaned up all route and main references - no direct access to db.getRecordset - only over interfaces --- app.py | 7 - modules/aicore/aicorePluginAnthropic.py | 4 +- modules/aicore/aicorePluginInternal.py | 6 +- modules/aicore/aicorePluginOpenai.py | 8 +- modules/aicore/aicorePluginPerplexity.py | 4 +- modules/aicore/aicorePluginTavily.py | 2 +- modules/datamodels/datamodelAi.py | 4 +- modules/datamodels/datamodelChat.py | 4 +- modules/features/automation/mainAutomation.py | 140 ++++ modules/features/chatplayground/__init__.py | 6 + .../interfaceFeatureChatplayground.py | 145 ++++ .../chatplayground/mainChatplayground.py | 273 +++++++ .../routeFeatureChatplayground.py | 233 ++++++ .../mainNeutralizePlayground.py | 12 +- modules/features/realEstate/mainRealEstate.py | 56 +- modules/features/trustee/mainTrustee.py | 69 +- .../features/trustee/routeFeatureTrustee.py | 69 +- modules/interfaces/interfaceAiObjects.py | 10 +- modules/interfaces/interfaceBootstrap.py | 65 ++ modules/interfaces/interfaceDbApp.py | 771 ++++++++++++++++++ modules/routes/routeAdminFeatures.py | 215 ++--- modules/routes/routeAdminRbacExport.py | 129 ++- modules/routes/routeAdminRbacRoles.py | 104 +-- modules/routes/routeAdminRbacRules.py | 43 +- .../routes/routeAdminUserAccessOverview.py | 170 ++-- modules/routes/routeChat.py | 128 --- modules/routes/routeDataConnections.py | 22 +- modules/routes/routeDataMandates.py | 107 +-- modules/routes/routeDataUsers.py | 52 +- modules/routes/routeDataWorkflows.py | 51 +- modules/routes/routeGdpr.py | 118 +-- modules/routes/routeInvitations.py | 174 ++-- modules/routes/routeMessaging.py | 7 +- modules/routes/routeNotifications.py | 177 ++-- modules/routes/routeSecurityAdmin.py | 12 +- modules/routes/routeSecurityGoogle.py | 32 +- modules/routes/routeSecurityLocal.py | 28 +- modules/routes/routeSecurityMsft.py | 11 +- modules/routes/routeSystem.py | 85 +- modules/security/rbac.py | 26 +- .../services/serviceChat/mainServiceChat.py | 2 +- .../mainServiceExtraction.py | 22 +- modules/shared/gdprDeletion.py | 18 +- modules/system/mainSystem.py | 153 ++-- .../methods/methodAi/actions/process.py | 2 +- tests/functional/test02_ai_models.py | 2 +- 46 files changed, 2462 insertions(+), 1316 deletions(-) create mode 100644 modules/features/chatplayground/__init__.py create mode 100644 modules/features/chatplayground/interfaceFeatureChatplayground.py create mode 100644 modules/features/chatplayground/mainChatplayground.py create mode 100644 modules/features/chatplayground/routeFeatureChatplayground.py delete mode 100644 modules/routes/routeChat.py diff --git a/app.py b/app.py index ce3b3ff1..474de4d6 100644 --- a/app.py +++ b/app.py @@ -485,7 +485,6 @@ app.include_router(rbacAdminRulesRouter) from modules.routes.routeMessaging import router as messagingRouter app.include_router(messagingRouter) -# Phase 8: New Feature Routes from modules.routes.routeAdminFeatures import router as featuresAdminRouter app.include_router(featuresAdminRouter) @@ -504,12 +503,6 @@ app.include_router(userAccessOverviewRouter) from modules.routes.routeGdpr import router as gdprRouter app.include_router(gdprRouter) -from modules.routes.routeChat import router as chatRouter -app.include_router(chatRouter) - -from modules.features.chatbot.routeFeatureChatbot import router as chatbotFeatureRouter -app.include_router(chatbotFeatureRouter) - # ============================================================================ # SYSTEM ROUTES (Navigation, etc.) # ============================================================================ diff --git a/modules/aicore/aicorePluginAnthropic.py b/modules/aicore/aicorePluginAnthropic.py index 0d80aeaa..eeea9a07 100644 --- a/modules/aicore/aicorePluginAnthropic.py +++ b/modules/aicore/aicorePluginAnthropic.py @@ -72,7 +72,7 @@ class AiAnthropic(BaseConnectorAi): (OperationTypeEnum.DATA_EXTRACT, 8) ), version="claude-sonnet-4-5-20250929", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075 ), AiModel( name="claude-sonnet-4-5-20250929", @@ -93,7 +93,7 @@ class AiAnthropic(BaseConnectorAi): (OperationTypeEnum.IMAGE_ANALYSE, 10) ), version="claude-sonnet-4-5-20250929", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075 ) ] diff --git a/modules/aicore/aicorePluginInternal.py b/modules/aicore/aicorePluginInternal.py index 1b73c27e..59854629 100644 --- a/modules/aicore/aicorePluginInternal.py +++ b/modules/aicore/aicorePluginInternal.py @@ -40,7 +40,7 @@ class AiInternal(BaseConnectorAi): processingMode=ProcessingModeEnum.BASIC, operationTypes=createOperationTypeRatings(), version="internal-extractor-v1", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.001 + (bytesSent + bytesReceived) / (1024 * 1024) * 0.01 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: 0.001 + (bytesSent + bytesReceived) / (1024 * 1024) * 0.01 ), AiModel( name="internal-generator", @@ -60,7 +60,7 @@ class AiInternal(BaseConnectorAi): processingMode=ProcessingModeEnum.BASIC, operationTypes=createOperationTypeRatings(), version="internal-generator-v1", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.002 + (bytesReceived / (1024 * 1024)) * 0.005 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: 0.002 + (bytesReceived / (1024 * 1024)) * 0.005 ), AiModel( name="internal-renderer", @@ -80,7 +80,7 @@ class AiInternal(BaseConnectorAi): processingMode=ProcessingModeEnum.DETAILED, operationTypes=createOperationTypeRatings(), version="internal-renderer-v1", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.003 + (bytesReceived / (1024 * 1024)) * 0.008 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: 0.003 + (bytesReceived / (1024 * 1024)) * 0.008 ) ] diff --git a/modules/aicore/aicorePluginOpenai.py b/modules/aicore/aicorePluginOpenai.py index c35c6dd6..931ece10 100644 --- a/modules/aicore/aicorePluginOpenai.py +++ b/modules/aicore/aicorePluginOpenai.py @@ -72,7 +72,7 @@ class AiOpenai(BaseConnectorAi): (OperationTypeEnum.DATA_EXTRACT, 7) ), version="gpt-4o", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06 ), AiModel( name="gpt-3.5-turbo", @@ -97,7 +97,7 @@ class AiOpenai(BaseConnectorAi): # Note: GPT-3.5-turbo does NOT support vision/image operations ), version="gpt-3.5-turbo", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0015 + (bytesReceived / 4 / 1000) * 0.002 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0015 + (bytesReceived / 4 / 1000) * 0.002 ), AiModel( name="gpt-4o", @@ -118,7 +118,7 @@ class AiOpenai(BaseConnectorAi): (OperationTypeEnum.IMAGE_ANALYSE, 9) ), version="gpt-4o", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06 ), AiModel( name="dall-e-3", @@ -140,7 +140,7 @@ class AiOpenai(BaseConnectorAi): (OperationTypeEnum.IMAGE_GENERATE, 10) ), version="dall-e-3", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.04 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.04 ) ] diff --git a/modules/aicore/aicorePluginPerplexity.py b/modules/aicore/aicorePluginPerplexity.py index e6d1ba10..7cb5e928 100644 --- a/modules/aicore/aicorePluginPerplexity.py +++ b/modules/aicore/aicorePluginPerplexity.py @@ -74,7 +74,7 @@ class AiPerplexity(BaseConnectorAi): (OperationTypeEnum.WEB_CRAWL, 7) ), version="sonar", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.005 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.005 ), AiModel( name="sonar-pro", @@ -97,7 +97,7 @@ class AiPerplexity(BaseConnectorAi): (OperationTypeEnum.WEB_CRAWL, 8) ), version="sonar-pro", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.01 + (bytesReceived / 4 / 1000) * 0.01 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.01 + (bytesReceived / 4 / 1000) * 0.01 ) ] diff --git a/modules/aicore/aicorePluginTavily.py b/modules/aicore/aicorePluginTavily.py index 1d2ece75..635cd4eb 100644 --- a/modules/aicore/aicorePluginTavily.py +++ b/modules/aicore/aicorePluginTavily.py @@ -71,7 +71,7 @@ class AiTavily(BaseConnectorAi): (OperationTypeEnum.WEB_CRAWL, 10) ), version="tavily-search", - calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.008 # Simple flat rate + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: 0.008 # Simple flat rate ) ] diff --git a/modules/datamodels/datamodelAi.py b/modules/datamodels/datamodelAi.py index 69d51871..c9d81bfa 100644 --- a/modules/datamodels/datamodelAi.py +++ b/modules/datamodels/datamodelAi.py @@ -98,7 +98,7 @@ class AiModel(BaseModel): # Function reference (not serialized) functionCall: Optional[Callable] = Field(default=None, exclude=True, description="Function to call for this model") - calculatePriceUsd: Optional[Callable] = Field(default=None, exclude=True, description="Function to calculate price in USD") + calculatepriceCHF: Optional[Callable] = Field(default=None, exclude=True, description="Function to calculate price in USD") # Selection criteria - capabilities with ratings priority: PriorityEnum = Field(default=PriorityEnum.BALANCED, description="Default priority for this model. See PriorityEnum for available values.") @@ -159,7 +159,7 @@ class AiCallResponse(BaseModel): content: str = Field(description="AI response content") modelName: str = Field(description="Selected model name") - priceUsd: float = Field(default=0.0, description="Calculated price in USD") + priceCHF: float = Field(default=0.0, description="Calculated price in USD") processingTime: float = Field(default=0.0, description="Duration in seconds") bytesSent: int = Field(default=0, description="Input data size in bytes") bytesReceived: int = Field(default=0, description="Output data size in bytes") diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index 22c07aa2..8ba3ced1 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -26,7 +26,7 @@ class ChatStat(BaseModel): errorCount: Optional[int] = Field(None, description="Number of errors encountered") process: Optional[str] = Field(None, description="The process that delivers the stats data (e.g. 'action.outlook.readMails', 'ai.process.document.name')") engine: Optional[str] = Field(None, description="The engine used (e.g. 'ai.anthropic.35', 'ai.tavily.basic', 'renderer.docx')") - priceUsd: Optional[float] = Field(None, description="Calculated price in USD for the operation") + priceCHF: Optional[float] = Field(None, description="Calculated price in USD for the operation") registerModelLabels( @@ -41,7 +41,7 @@ registerModelLabels( "errorCount": {"en": "Error Count", "fr": "Nombre d'erreurs"}, "process": {"en": "Process", "fr": "Processus"}, "engine": {"en": "Engine", "fr": "Moteur"}, - "priceUsd": {"en": "Price USD", "fr": "Prix USD"}, + "priceCHF": {"en": "Price USD", "fr": "Prix USD"}, }, ) diff --git a/modules/features/automation/mainAutomation.py b/modules/features/automation/mainAutomation.py index 88828442..2b8443a9 100644 --- a/modules/features/automation/mainAutomation.py +++ b/modules/features/automation/mainAutomation.py @@ -59,7 +59,24 @@ RESOURCE_OBJECTS = [ ] # Template roles for this feature +# IMPORTANT: "viewer" role is required for automatic user assignment! TEMPLATE_ROLES = [ + { + "roleLabel": "viewer", + "description": { + "en": "Automation Viewer - View automations and execution results", + "de": "Automatisierungs-Betrachter - Automatisierungen und Ausführungsergebnisse einsehen", + "fr": "Visualiseur automatisation - Consulter les automatisations et résultats" + }, + "accessRules": [ + # UI access to all views + {"context": "UI", "item": "ui.feature.automation.definitions", "view": True}, + {"context": "UI", "item": "ui.feature.automation.templates", "view": True}, + {"context": "UI", "item": "ui.feature.automation.logs", "view": True}, + # Read-only DATA access + {"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "n"}, + ] + }, { "roleLabel": "automation-admin", "description": { @@ -161,9 +178,132 @@ def registerFeature(catalogService) -> bool: meta=resObj.get("meta") ) + # Sync template roles to database + _syncTemplateRolesToDb() + logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI objects and {len(RESOURCE_OBJECTS)} resource objects") return True except Exception as e: logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}") return False + + +def _syncTemplateRolesToDb() -> int: + """ + Sync template roles and their AccessRules to the database. + Creates global template roles (mandateId=None) if they don't exist. + + Returns: + Number of roles created/updated + """ + try: + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext + + rootInterface = getRootInterface() + + # Get existing template roles for this feature (Pydantic models) + existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE) + # Filter to template roles (mandateId is None) + templateRoles = [r for r in existingRoles if r.mandateId is None] + existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles} + + createdCount = 0 + for roleTemplate in TEMPLATE_ROLES: + roleLabel = roleTemplate["roleLabel"] + + if roleLabel in existingRoleLabels: + roleId = existingRoleLabels[roleLabel] + logger.debug(f"Template role '{roleLabel}' already exists with ID {roleId}") + + # Ensure AccessRules exist for this role + _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) + else: + # Create new template role + newRole = Role( + roleLabel=roleLabel, + description=roleTemplate.get("description", {}), + featureCode=FEATURE_CODE, + mandateId=None, # Global template + featureInstanceId=None, + isSystemRole=False + ) + createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump()) + roleId = createdRole.get("id") + + # Create AccessRules for this role + _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) + + logger.info(f"Created template role '{roleLabel}' with ID {roleId}") + createdCount += 1 + + if createdCount > 0: + logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles") + + return createdCount + + except Exception as e: + logger.error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}") + return 0 + + +def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int: + """ + Ensure AccessRules exist for a role based on templates. + + Args: + rootInterface: Root interface instance + roleId: Role ID + ruleTemplates: List of rule templates + + Returns: + Number of rules created + """ + from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext + + # Get existing rules for this role (Pydantic models) + existingRules = rootInterface.getAccessRulesByRole(roleId) + + # Create a set of existing rule signatures to avoid duplicates + existingSignatures = set() + for rule in existingRules: + sig = (str(rule.context) if rule.context else None, rule.item) + existingSignatures.add(sig) + + createdCount = 0 + for template in ruleTemplates: + context = template.get("context", "UI") + item = template.get("item") + sig = (context, item) + + if sig in existingSignatures: + continue + + # Map context string to enum + if context == "UI": + contextEnum = AccessRuleContext.UI + elif context == "DATA": + contextEnum = AccessRuleContext.DATA + elif context == "RESOURCE": + contextEnum = AccessRuleContext.RESOURCE + else: + contextEnum = context + + newRule = AccessRule( + roleId=roleId, + context=contextEnum, + item=item, + view=template.get("view", False), + read=template.get("read"), + create=template.get("create"), + update=template.get("update"), + delete=template.get("delete"), + ) + rootInterface.db.recordCreate(AccessRule, newRule.model_dump()) + createdCount += 1 + + if createdCount > 0: + logger.debug(f"Created {createdCount} AccessRules for role {roleId}") + + return createdCount diff --git a/modules/features/chatplayground/__init__.py b/modules/features/chatplayground/__init__.py new file mode 100644 index 00000000..4b2f2bd4 --- /dev/null +++ b/modules/features/chatplayground/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Chat Playground Feature Container. +Provides workflow-based chat playground functionality. +""" diff --git a/modules/features/chatplayground/interfaceFeatureChatplayground.py b/modules/features/chatplayground/interfaceFeatureChatplayground.py new file mode 100644 index 00000000..5a2548ba --- /dev/null +++ b/modules/features/chatplayground/interfaceFeatureChatplayground.py @@ -0,0 +1,145 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Chat Playground Feature Interface. +Wrapper around interfaceDbChat with feature instance context. +""" + +import logging +from typing import Dict, Any, List, Optional + +from modules.datamodels.datamodelUam import User +from modules.interfaces import interfaceDbChat + +logger = logging.getLogger(__name__) + +# Feature code constant +FEATURE_CODE = "chatplayground" + +# Singleton instances cache +_instances: Dict[str, "ChatPlaygroundObjects"] = {} + + +def getInterface(currentUser: User, mandateId: str = None, featureInstanceId: str = None) -> "ChatPlaygroundObjects": + """ + Factory function to get or create a ChatPlaygroundObjects instance. + Uses singleton pattern per user context. + + Args: + currentUser: Current user object + mandateId: Mandate ID + featureInstanceId: Feature instance ID + + Returns: + ChatPlaygroundObjects instance + """ + cacheKey = f"{currentUser.id}_{mandateId}_{featureInstanceId}" + + if cacheKey not in _instances: + _instances[cacheKey] = ChatPlaygroundObjects(currentUser, mandateId, featureInstanceId) + else: + # Update context if needed + _instances[cacheKey].setUserContext(currentUser, mandateId, featureInstanceId) + + return _instances[cacheKey] + + +class ChatPlaygroundObjects: + """ + Chat Playground feature interface. + Wraps the shared interfaceDbChat with feature instance context. + """ + + FEATURE_CODE = FEATURE_CODE + + def __init__(self, currentUser: User, mandateId: str = None, featureInstanceId: str = None): + """ + Initialize the Chat Playground interface. + + Args: + currentUser: Current user object + mandateId: Mandate ID + featureInstanceId: Feature instance ID + """ + self.currentUser = currentUser + self.mandateId = mandateId + self.featureInstanceId = featureInstanceId + + # Get the underlying chat interface + self._chatInterface = interfaceDbChat.getInterface( + currentUser, + mandateId=mandateId, + featureInstanceId=featureInstanceId + ) + + def setUserContext(self, currentUser: User, mandateId: str = None, featureInstanceId: str = None): + """ + Update the user context. + + Args: + currentUser: Current user object + mandateId: Mandate ID + featureInstanceId: Feature instance ID + """ + self.currentUser = currentUser + self.mandateId = mandateId + self.featureInstanceId = featureInstanceId + + # Update underlying interface + self._chatInterface = interfaceDbChat.getInterface( + currentUser, + mandateId=mandateId, + featureInstanceId=featureInstanceId + ) + + # ========================================================================= + # Delegated methods from interfaceDbChat + # ========================================================================= + + def getWorkflow(self, workflowId: str) -> Optional[Dict[str, Any]]: + """Get a workflow by ID.""" + return self._chatInterface.getWorkflow(workflowId) + + def getWorkflows(self, pagination=None) -> Dict[str, Any]: + """Get all workflows with pagination.""" + return self._chatInterface.getWorkflows(pagination=pagination) + + def getUnifiedChatData(self, workflowId: str, afterTimestamp: float = None) -> Dict[str, Any]: + """Get unified chat data for a workflow.""" + return self._chatInterface.getUnifiedChatData(workflowId, afterTimestamp) + + def createWorkflow(self, workflow) -> Dict[str, Any]: + """Create a new workflow.""" + return self._chatInterface.createWorkflow(workflow) + + def updateWorkflow(self, workflowId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update a workflow.""" + return self._chatInterface.updateWorkflow(workflowId, updates) + + def deleteWorkflow(self, workflowId: str) -> bool: + """Delete a workflow.""" + return self._chatInterface.deleteWorkflow(workflowId) + + def getMessages(self, workflowId: str) -> List[Dict[str, Any]]: + """Get messages for a workflow.""" + return self._chatInterface.getMessages(workflowId) + + def createMessage(self, message) -> Dict[str, Any]: + """Create a new message.""" + return self._chatInterface.createMessage(message) + + def getLogs(self, workflowId: str) -> List[Dict[str, Any]]: + """Get logs for a workflow.""" + return self._chatInterface.getLogs(workflowId) + + def createLog(self, log) -> Dict[str, Any]: + """Create a new log entry.""" + return self._chatInterface.createLog(log) + + def getStats(self, workflowId: str) -> List[Dict[str, Any]]: + """Get stats for a workflow.""" + return self._chatInterface.getStats(workflowId) + + def createStat(self, stat) -> Dict[str, Any]: + """Create a new stat entry.""" + return self._chatInterface.createStat(stat) diff --git a/modules/features/chatplayground/mainChatplayground.py b/modules/features/chatplayground/mainChatplayground.py new file mode 100644 index 00000000..ed0e2868 --- /dev/null +++ b/modules/features/chatplayground/mainChatplayground.py @@ -0,0 +1,273 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Chat Playground Feature Container - Main Module. +Handles feature initialization and RBAC catalog registration. +""" + +import logging +from typing import Dict, List, Any + +logger = logging.getLogger(__name__) + +# Feature metadata +FEATURE_CODE = "chatplayground" +FEATURE_LABEL = {"en": "Chat Playground", "de": "Chat Playground", "fr": "Chat Playground"} +FEATURE_ICON = "mdi-message-text" + +# UI Objects for RBAC catalog +UI_OBJECTS = [ + { + "objectKey": "ui.feature.chatplayground.playground", + "label": {"en": "Playground", "de": "Playground", "fr": "Playground"}, + "meta": {"area": "playground"} + }, + { + "objectKey": "ui.feature.chatplayground.workflows", + "label": {"en": "Workflows", "de": "Workflows", "fr": "Workflows"}, + "meta": {"area": "workflows"} + }, +] + +# Resource Objects for RBAC catalog +RESOURCE_OBJECTS = [ + { + "objectKey": "resource.feature.chatplayground.start", + "label": {"en": "Start Workflow", "de": "Workflow starten", "fr": "Démarrer workflow"}, + "meta": {"endpoint": "/api/chatplayground/{instanceId}/start", "method": "POST"} + }, + { + "objectKey": "resource.feature.chatplayground.stop", + "label": {"en": "Stop Workflow", "de": "Workflow stoppen", "fr": "Arrêter workflow"}, + "meta": {"endpoint": "/api/chatplayground/{instanceId}/{workflowId}/stop", "method": "POST"} + }, + { + "objectKey": "resource.feature.chatplayground.chatData", + "label": {"en": "Get Chat Data", "de": "Chat-Daten abrufen", "fr": "Récupérer données chat"}, + "meta": {"endpoint": "/api/chatplayground/{instanceId}/{workflowId}/chatData", "method": "GET"} + }, +] + +# Template roles for this feature +# IMPORTANT: "viewer" role is required for automatic user assignment! +TEMPLATE_ROLES = [ + { + "roleLabel": "viewer", + "description": { + "en": "Chat Playground Viewer - View and use chat playground", + "de": "Chat Playground Betrachter - Chat Playground ansehen und nutzen", + "fr": "Visualiseur Chat Playground - Consulter et utiliser le chat playground" + }, + "accessRules": [ + # UI access to all views + {"context": "UI", "item": "ui.feature.chatplayground.playground", "view": True}, + {"context": "UI", "item": "ui.feature.chatplayground.workflows", "view": True}, + # Resource access + {"context": "RESOURCE", "item": "resource.feature.chatplayground.start", "view": True}, + {"context": "RESOURCE", "item": "resource.feature.chatplayground.stop", "view": True}, + {"context": "RESOURCE", "item": "resource.feature.chatplayground.chatData", "view": True}, + # DATA access (own records) + {"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "m"}, + ] + }, + { + "roleLabel": "admin", + "description": { + "en": "Chat Playground Admin - Full access to chat playground", + "de": "Chat Playground Admin - Vollzugriff auf Chat Playground", + "fr": "Administrateur Chat Playground - Accès complet au chat playground" + }, + "accessRules": [ + # Full UI access + {"context": "UI", "item": None, "view": True}, + # Full resource access + {"context": "RESOURCE", "item": None, "view": True}, + # Full DATA access + {"context": "DATA", "item": None, "view": True, "read": "a", "create": "a", "update": "a", "delete": "a"}, + ] + }, +] + + +def getFeatureDefinition() -> Dict[str, Any]: + """Return the feature definition for registration.""" + return { + "code": FEATURE_CODE, + "label": FEATURE_LABEL, + "icon": FEATURE_ICON + } + + +def getUiObjects() -> List[Dict[str, Any]]: + """Return UI objects for RBAC catalog registration.""" + return UI_OBJECTS + + +def getResourceObjects() -> List[Dict[str, Any]]: + """Return resource objects for RBAC catalog registration.""" + return RESOURCE_OBJECTS + + +def getTemplateRoles() -> List[Dict[str, Any]]: + """Return template roles for this feature.""" + return TEMPLATE_ROLES + + +def registerFeature(catalogService) -> bool: + """ + Register this feature's RBAC objects in the catalog. + + Args: + catalogService: The RBAC catalog service instance + + Returns: + True if registration was successful + """ + try: + # Register UI objects + for uiObj in UI_OBJECTS: + catalogService.registerUiObject( + featureCode=FEATURE_CODE, + objectKey=uiObj["objectKey"], + label=uiObj["label"], + meta=uiObj.get("meta") + ) + + # Register Resource objects + for resObj in RESOURCE_OBJECTS: + catalogService.registerResourceObject( + featureCode=FEATURE_CODE, + objectKey=resObj["objectKey"], + label=resObj["label"], + meta=resObj.get("meta") + ) + + # Sync template roles to database + _syncTemplateRolesToDb() + + logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI objects and {len(RESOURCE_OBJECTS)} resource objects") + return True + + except Exception as e: + logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}") + return False + + +def _syncTemplateRolesToDb() -> int: + """ + Sync template roles and their AccessRules to the database. + Creates global template roles (mandateId=None) if they don't exist. + + Returns: + Number of roles created/updated + """ + try: + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext + + rootInterface = getRootInterface() + + # Get existing template roles for this feature (Pydantic models) + existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE) + # Filter to template roles (mandateId is None) + templateRoles = [r for r in existingRoles if r.mandateId is None] + existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles} + + createdCount = 0 + for roleTemplate in TEMPLATE_ROLES: + roleLabel = roleTemplate["roleLabel"] + + if roleLabel in existingRoleLabels: + roleId = existingRoleLabels[roleLabel] + logger.debug(f"Template role '{roleLabel}' already exists with ID {roleId}") + + # Ensure AccessRules exist for this role + _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) + else: + # Create new template role + newRole = Role( + roleLabel=roleLabel, + description=roleTemplate.get("description", {}), + featureCode=FEATURE_CODE, + mandateId=None, # Global template + featureInstanceId=None, + isSystemRole=False + ) + createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump()) + roleId = createdRole.get("id") + + # Create AccessRules for this role + _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) + + logger.info(f"Created template role '{roleLabel}' with ID {roleId}") + createdCount += 1 + + if createdCount > 0: + logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles") + + return createdCount + + except Exception as e: + logger.error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}") + return 0 + + +def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int: + """ + Ensure AccessRules exist for a role based on templates. + + Args: + rootInterface: Root interface instance + roleId: Role ID + ruleTemplates: List of rule templates + + Returns: + Number of rules created + """ + from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext + + # Get existing rules for this role (Pydantic models) + existingRules = rootInterface.getAccessRulesByRole(roleId) + + # Create a set of existing rule signatures to avoid duplicates + existingSignatures = set() + for rule in existingRules: + sig = (str(rule.context) if rule.context else None, rule.item) + existingSignatures.add(sig) + + createdCount = 0 + for template in ruleTemplates: + context = template.get("context", "UI") + item = template.get("item") + sig = (context, item) + + if sig in existingSignatures: + continue + + # Map context string to enum + if context == "UI": + contextEnum = AccessRuleContext.UI + elif context == "DATA": + contextEnum = AccessRuleContext.DATA + elif context == "RESOURCE": + contextEnum = AccessRuleContext.RESOURCE + else: + contextEnum = context + + newRule = AccessRule( + roleId=roleId, + context=contextEnum, + item=item, + view=template.get("view", False), + read=template.get("read"), + create=template.get("create"), + update=template.get("update"), + delete=template.get("delete"), + ) + rootInterface.db.recordCreate(AccessRule, newRule.model_dump()) + createdCount += 1 + + if createdCount > 0: + logger.debug(f"Created {createdCount} AccessRules for role {roleId}") + + return createdCount diff --git a/modules/features/chatplayground/routeFeatureChatplayground.py b/modules/features/chatplayground/routeFeatureChatplayground.py new file mode 100644 index 00000000..6a76e70e --- /dev/null +++ b/modules/features/chatplayground/routeFeatureChatplayground.py @@ -0,0 +1,233 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Chat Playground Feature Routes. +Implements the endpoints for chat playground workflow management as a feature. +""" + +import logging +from typing import Optional, Dict, Any +from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request + +# Import auth modules +from modules.auth import limiter, getRequestContext, RequestContext + +# Import interfaces +from modules.interfaces import interfaceDbChat + +# Import models +from modules.datamodels.datamodelChat import ChatWorkflow, UserInputRequest, WorkflowModeEnum + +# Import workflow control functions +from modules.workflows.automation import chatStart, chatStop + +# Configure logger +logger = logging.getLogger(__name__) + +# Create router for chat playground feature endpoints +router = APIRouter( + prefix="/api/chatplayground", + tags=["Chat Playground Feature"], + responses={404: {"description": "Not found"}} +) + + +def _getServiceChat(context: RequestContext, featureInstanceId: str = None): + """Get chat interface with feature instance context.""" + return interfaceDbChat.getInterface( + context.user, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=featureInstanceId + ) + + +async def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: + """ + Validate that user has access to the feature instance. + + Args: + instanceId: Feature instance ID + context: Request context + + Returns: + mandateId for the instance + + Raises: + HTTPException if access is denied + """ + from modules.interfaces.interfaceDbApp import getRootInterface + + rootInterface = getRootInterface() + + # Get feature instance (Pydantic model) + instance = rootInterface.getFeatureInstance(instanceId) + if not instance: + raise HTTPException(status_code=404, detail=f"Feature instance {instanceId} not found") + + # Check user has access to this instance using interface method + featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId) + + if not featureAccess or not featureAccess.enabled: + raise HTTPException(status_code=403, detail="Access denied to this feature instance") + + return str(instance.mandateId) if instance.mandateId else None + + +# Workflow start endpoint +@router.post("/{instanceId}/start", response_model=ChatWorkflow) +@limiter.limit("120/minute") +async def start_workflow( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"), + workflowMode: WorkflowModeEnum = Query(..., description="Workflow mode: 'Dynamic' or 'Automation' (mandatory)"), + userInput: UserInputRequest = Body(...), + context: RequestContext = Depends(getRequestContext) +) -> ChatWorkflow: + """ + Starts a new workflow or continues an existing one. + + Args: + instanceId: Feature instance ID + workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution + """ + try: + # Validate access and get mandate ID + mandateId = await _validateInstanceAccess(instanceId, context) + + # Start or continue workflow + workflow = await chatStart( + context.user, + userInput, + workflowMode, + workflowId, + mandateId=mandateId, + featureInstanceId=instanceId + ) + + return workflow + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error in start_workflow: {str(e)}") + raise HTTPException( + status_code=500, + detail=str(e) + ) + + +# Stop workflow endpoint +@router.post("/{instanceId}/{workflowId}/stop", response_model=ChatWorkflow) +@limiter.limit("120/minute") +async def stop_workflow( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + workflowId: str = Path(..., description="ID of the workflow to stop"), + context: RequestContext = Depends(getRequestContext) +) -> ChatWorkflow: + """Stops a running workflow.""" + try: + # Validate access and get mandate ID + mandateId = await _validateInstanceAccess(instanceId, context) + + # Stop workflow + workflow = await chatStop( + context.user, + workflowId, + mandateId=mandateId, + featureInstanceId=instanceId + ) + + return workflow + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error in stop_workflow: {str(e)}") + raise HTTPException( + status_code=500, + detail=str(e) + ) + + +# Unified Chat Data Endpoint for Polling +@router.get("/{instanceId}/{workflowId}/chatData") +@limiter.limit("120/minute") +async def get_workflow_chat_data( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + workflowId: str = Path(..., description="ID of the workflow"), + afterTimestamp: Optional[float] = Query(None, description="Unix timestamp to get data after"), + context: RequestContext = Depends(getRequestContext) +) -> Dict[str, Any]: + """ + Get unified chat data (messages, logs, stats) for a workflow with timestamp-based selective data transfer. + Returns all data types in chronological order based on _createdAt timestamp. + """ + try: + # Validate access + await _validateInstanceAccess(instanceId, context) + + # Get service with feature instance context + chatInterface = _getServiceChat(context, featureInstanceId=instanceId) + + # Verify workflow exists + workflow = chatInterface.getWorkflow(workflowId) + if not workflow: + raise HTTPException( + status_code=404, + detail=f"Workflow with ID {workflowId} not found" + ) + + # Get unified chat data + chatData = chatInterface.getUnifiedChatData(workflowId, afterTimestamp) + + return chatData + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting unified chat data: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Error getting unified chat data: {str(e)}" + ) + + +# Get workflows for this instance +@router.get("/{instanceId}/workflows") +@limiter.limit("120/minute") +async def get_workflows( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + page: int = Query(1, ge=1, description="Page number"), + pageSize: int = Query(20, ge=1, le=100, description="Items per page"), + context: RequestContext = Depends(getRequestContext) +) -> Dict[str, Any]: + """ + Get all workflows for this feature instance. + """ + try: + # Validate access + await _validateInstanceAccess(instanceId, context) + + # Get service with feature instance context + chatInterface = _getServiceChat(context, featureInstanceId=instanceId) + + # Get workflows with pagination + from modules.datamodels.datamodelPagination import PaginationParams + pagination = PaginationParams(page=page, pageSize=pageSize) + + result = chatInterface.getWorkflows(pagination=pagination) + + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting workflows: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Error getting workflows: {str(e)}" + ) diff --git a/modules/features/neutralization/mainNeutralizePlayground.py b/modules/features/neutralization/mainNeutralizePlayground.py index bf9aa087..159faf04 100644 --- a/modules/features/neutralization/mainNeutralizePlayground.py +++ b/modules/features/neutralization/mainNeutralizePlayground.py @@ -182,20 +182,18 @@ class SharepointProcessor: async def _getSharepointConnection(self, sharepointPath: str = None): try: - connections = self.services.interfaceDbApp.db.getRecordset( - UserConnection, - recordFilter={"userId": self.services.interfaceDbApp.userId} - ) - msftConnections = [c for c in connections if c.get('authority') == 'msft'] + # Use interface method to get user connections + connections = self.services.interfaceDbApp.getUserConnections(self.services.interfaceDbApp.userId) + msftConnections = [c for c in connections if c.authority == 'msft'] if not msftConnections: logger.warning('No Microsoft connections found for user') return None if len(msftConnections) == 1: - logger.info(f"Found single Microsoft connection: {msftConnections[0].get('id')}") + logger.info(f"Found single Microsoft connection: {msftConnections[0].id}") return msftConnections[0] if sharepointPath: return await self._matchConnectionToPath(msftConnections, sharepointPath) - logger.info(f"Multiple Microsoft connections found, using first one: {msftConnections[0].get('id')}") + logger.info(f"Multiple Microsoft connections found, using first one: {msftConnections[0].id}") return msftConnections[0] except Exception: logger.error('Error getting SharePoint connection') diff --git a/modules/features/realEstate/mainRealEstate.py b/modules/features/realEstate/mainRealEstate.py index 5e43ceab..8562f5b8 100644 --- a/modules/features/realEstate/mainRealEstate.py +++ b/modules/features/realEstate/mainRealEstate.py @@ -165,13 +165,11 @@ def _syncTemplateRolesToDb() -> int: from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext rootInterface = getRootInterface() - db = rootInterface.db - existingRoles = db.getRecordset( - Role, - recordFilter={"featureCode": FEATURE_CODE, "mandateId": None} - ) - existingRoleLabels = {r.get("roleLabel"): r.get("id") for r in existingRoles} + # Get existing template roles (Pydantic models) + existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE) + templateRoles = [r for r in existingRoles if r.mandateId is None] + existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles} createdCount = 0 for roleTemplate in TEMPLATE_ROLES: @@ -179,7 +177,7 @@ def _syncTemplateRolesToDb() -> int: if roleLabel in existingRoleLabels: roleId = existingRoleLabels[roleLabel] - _ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", [])) + _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) else: newRole = Role( roleLabel=roleLabel, @@ -189,65 +187,65 @@ def _syncTemplateRolesToDb() -> int: featureInstanceId=None, isSystemRole=False ) - createdRole = db.recordCreate(Role, newRole.model_dump()) + createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump()) roleId = createdRole.get("id") existingRoleLabels[roleLabel] = roleId - _ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", [])) + _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) logging.getLogger(__name__).info(f"Created template role '{roleLabel}' with ID {roleId}") createdCount += 1 if createdCount > 0: logging.getLogger(__name__).info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles") - _repairInstanceRolesAccessRules(db, existingRoleLabels) + _repairInstanceRolesAccessRules(rootInterface, existingRoleLabels) return createdCount except Exception as e: logging.getLogger(__name__).error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}") return 0 -def _repairInstanceRolesAccessRules(db, templateRoleLabels: dict) -> int: +def _repairInstanceRolesAccessRules(rootInterface, templateRoleLabels: dict) -> int: """Repair instance-specific roles by copying AccessRules from their template roles.""" from modules.datamodels.datamodelRbac import Role, AccessRule repairedCount = 0 - allRoles = db.getRecordset(Role, recordFilter={"featureCode": FEATURE_CODE}) - instanceRoles = [r for r in allRoles if r.get("mandateId") is not None] + allRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE) + instanceRoles = [r for r in allRoles if r.mandateId is not None] for instanceRole in instanceRoles: - roleLabel = instanceRole.get("roleLabel") - instanceRoleId = instanceRole.get("id") + roleLabel = instanceRole.roleLabel + instanceRoleId = str(instanceRole.id) templateRoleId = templateRoleLabels.get(roleLabel) if not templateRoleId: continue - existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": instanceRoleId}) + existingRules = rootInterface.getAccessRulesByRole(instanceRoleId) if existingRules: continue - templateRules = db.getRecordset(AccessRule, recordFilter={"roleId": templateRoleId}) + templateRules = rootInterface.getAccessRulesByRole(templateRoleId) if not templateRules: continue for rule in templateRules: newRule = AccessRule( roleId=instanceRoleId, - context=rule.get("context"), - item=rule.get("item"), - view=rule.get("view", False), - read=rule.get("read"), - create=rule.get("create"), - update=rule.get("update"), - delete=rule.get("delete"), + context=rule.context, + item=rule.item, + view=rule.view if rule.view else False, + read=rule.read, + create=rule.create, + update=rule.update, + delete=rule.delete, ) - db.recordCreate(AccessRule, newRule.model_dump()) + rootInterface.db.recordCreate(AccessRule, newRule.model_dump()) repairedCount += 1 return repairedCount -def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: list) -> int: +def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: list) -> int: """Ensure AccessRules exist for a role based on templates.""" from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext - existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": roleId}) - existingSignatures = {(r.get("context"), r.get("item")) for r in existingRules} + existingRules = rootInterface.getAccessRulesByRole(roleId) + existingSignatures = {(str(r.context) if r.context else None, r.item) for r in existingRules} createdCount = 0 for template in ruleTemplates or []: @@ -273,7 +271,7 @@ def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: list) -> int: update=template.get("update"), delete=template.get("delete"), ) - db.recordCreate(AccessRule, newRule.model_dump()) + rootInterface.db.recordCreate(AccessRule, newRule.model_dump()) createdCount += 1 existingSignatures.add((context, item)) return createdCount diff --git a/modules/features/trustee/mainTrustee.py b/modules/features/trustee/mainTrustee.py index 4f1694b5..ad449d8f 100644 --- a/modules/features/trustee/mainTrustee.py +++ b/modules/features/trustee/mainTrustee.py @@ -267,14 +267,11 @@ def _syncTemplateRolesToDb() -> int: from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext rootInterface = getRootInterface() - db = rootInterface.db - # Get existing template roles for this feature - existingRoles = db.getRecordset( - Role, - recordFilter={"featureCode": FEATURE_CODE, "mandateId": None} - ) - existingRoleLabels = {r.get("roleLabel"): r.get("id") for r in existingRoles} + # Get existing template roles for this feature (Pydantic models) + existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE) + templateRoles = [r for r in existingRoles if r.mandateId is None] + existingRoleLabels = {r.roleLabel: str(r.id) for r in templateRoles} createdCount = 0 for roleTemplate in TEMPLATE_ROLES: @@ -285,7 +282,7 @@ def _syncTemplateRolesToDb() -> int: logger.debug(f"Template role '{roleLabel}' already exists with ID {roleId}") # Ensure AccessRules exist for this role - _ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", [])) + _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) else: # Create new template role newRole = Role( @@ -296,11 +293,11 @@ def _syncTemplateRolesToDb() -> int: featureInstanceId=None, isSystemRole=False ) - createdRole = db.recordCreate(Role, newRole.model_dump()) + createdRole = rootInterface.db.recordCreate(Role, newRole.model_dump()) roleId = createdRole.get("id") # Create AccessRules for this role - _ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", [])) + _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) logger.info(f"Created template role '{roleLabel}' with ID {roleId}") createdCount += 1 @@ -309,7 +306,7 @@ def _syncTemplateRolesToDb() -> int: logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles") # Repair instance-specific roles that are missing AccessRules - _repairInstanceRolesAccessRules(db, existingRoleLabels) + _repairInstanceRolesAccessRules(rootInterface, existingRoleLabels) return createdCount @@ -318,13 +315,13 @@ def _syncTemplateRolesToDb() -> int: return 0 -def _repairInstanceRolesAccessRules(db, templateRoleLabels: Dict[str, str]) -> int: +def _repairInstanceRolesAccessRules(rootInterface, templateRoleLabels: Dict[str, str]) -> int: """ Repair instance-specific roles by copying AccessRules from their template roles. This ensures instance roles created before AccessRules were defined get updated. Args: - db: Database connector + rootInterface: Root interface instance templateRoleLabels: Dict mapping roleLabel to template role ID Returns: @@ -334,41 +331,41 @@ def _repairInstanceRolesAccessRules(db, templateRoleLabels: Dict[str, str]) -> i repairedCount = 0 - # Get all instance-specific roles for this feature (mandateId is NOT None) - allRoles = db.getRecordset(Role, recordFilter={"featureCode": FEATURE_CODE}) - instanceRoles = [r for r in allRoles if r.get("mandateId") is not None] + # Get all instance-specific roles for this feature (Pydantic models) + allRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE) + instanceRoles = [r for r in allRoles if r.mandateId is not None] for instanceRole in instanceRoles: - roleLabel = instanceRole.get("roleLabel") - instanceRoleId = instanceRole.get("id") + roleLabel = instanceRole.roleLabel + instanceRoleId = str(instanceRole.id) # Find matching template role templateRoleId = templateRoleLabels.get(roleLabel) if not templateRoleId: continue - # Check if instance role has AccessRules - existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": instanceRoleId}) + # Check if instance role has AccessRules (Pydantic models) + existingRules = rootInterface.getAccessRulesByRole(instanceRoleId) if existingRules: continue # Already has rules, skip - # Copy AccessRules from template role - templateRules = db.getRecordset(AccessRule, recordFilter={"roleId": templateRoleId}) + # Copy AccessRules from template role (Pydantic models) + templateRules = rootInterface.getAccessRulesByRole(templateRoleId) if not templateRules: continue # Template has no rules for rule in templateRules: newRule = AccessRule( roleId=instanceRoleId, - context=rule.get("context"), - item=rule.get("item"), - view=rule.get("view", False), - read=rule.get("read"), - create=rule.get("create"), - update=rule.get("update"), - delete=rule.get("delete"), + context=rule.context, + item=rule.item, + view=rule.view if rule.view else False, + read=rule.read, + create=rule.create, + update=rule.update, + delete=rule.delete, ) - db.recordCreate(AccessRule, newRule.model_dump()) + rootInterface.db.recordCreate(AccessRule, newRule.model_dump()) logger.info(f"Repaired instance role '{roleLabel}' (ID: {instanceRoleId}): copied {len(templateRules)} AccessRules from template") repairedCount += 1 @@ -379,12 +376,12 @@ def _repairInstanceRolesAccessRules(db, templateRoleLabels: Dict[str, str]) -> i return repairedCount -def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int: +def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int: """ Ensure AccessRules exist for a role based on templates. Args: - db: Database connector + rootInterface: Root interface instance roleId: Role ID ruleTemplates: List of rule templates @@ -393,13 +390,13 @@ def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: List[Dict[str, Any """ from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext - # Get existing rules for this role - existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": roleId}) + # Get existing rules for this role (Pydantic models) + existingRules = rootInterface.getAccessRulesByRole(roleId) # Create a set of existing rule signatures to avoid duplicates existingSignatures = set() for rule in existingRules: - sig = (rule.get("context"), rule.get("item")) + sig = (str(rule.context) if rule.context else None, rule.item) existingSignatures.add(sig) createdCount = 0 @@ -431,7 +428,7 @@ def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: List[Dict[str, Any update=template.get("update"), delete=template.get("delete"), ) - db.recordCreate(AccessRule, newRule.model_dump()) + rootInterface.db.recordCreate(AccessRule, newRule.model_dump()) createdCount += 1 if createdCount > 0: diff --git a/modules/features/trustee/routeFeatureTrustee.py b/modules/features/trustee/routeFeatureTrustee.py index 9b1b1fca..43706a10 100644 --- a/modules/features/trustee/routeFeatureTrustee.py +++ b/modules/features/trustee/routeFeatureTrustee.py @@ -1363,17 +1363,11 @@ async def get_instance_roles( rootInterface = getRootInterface() - # Get instance-specific roles (mandateId set, featureInstanceId matches) - roles = rootInterface.db.getRecordset( - Role, - recordFilter={ - "featureCode": "trustee", - "featureInstanceId": instanceId - } - ) + # Get instance-specific roles (Pydantic models) + roles = rootInterface.getRolesByFeatureCode("trustee", featureInstanceId=instanceId) return PaginatedResponse( - items=roles, + items=[r.model_dump() for r in roles], pagination=None ) @@ -1390,18 +1384,16 @@ async def get_instance_role( mandateId = await _validateInstanceAdmin(instanceId, context) rootInterface = getRootInterface() - roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) + role = rootInterface.getRole(roleId) - if not roles: + if not role: raise HTTPException(status_code=404, detail=f"Role {roleId} not found") - role = roles[0] - # Verify role belongs to this instance - if role.get("featureInstanceId") != instanceId: + if str(role.featureInstanceId) != instanceId: raise HTTPException(status_code=404, detail=f"Role {roleId} not found in this instance") - return role + return role.model_dump() @router.get("/{instanceId}/instance-roles/{roleId}/rules", response_model=PaginatedResponse) @@ -1420,19 +1412,16 @@ async def get_instance_role_rules( rootInterface = getRootInterface() - # Verify role belongs to this instance - roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if not roles or roles[0].get("featureInstanceId") != instanceId: + # Verify role belongs to this instance (Pydantic model) + role = rootInterface.getRole(roleId) + if not role or str(role.featureInstanceId) != instanceId: raise HTTPException(status_code=404, detail=f"Role {roleId} not found in this instance") - # Get AccessRules for this role - rules = rootInterface.db.getRecordset( - AccessRule, - recordFilter={"roleId": roleId} - ) + # Get AccessRules for this role (Pydantic models) + rules = rootInterface.getAccessRulesByRole(roleId) return PaginatedResponse( - items=rules, + items=[r.model_dump() for r in rules], pagination=None ) @@ -1454,9 +1443,9 @@ async def create_instance_role_rule( rootInterface = getRootInterface() - # Verify role belongs to this instance - roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if not roles or roles[0].get("featureInstanceId") != instanceId: + # Verify role belongs to this instance (Pydantic model) + role = rootInterface.getRole(roleId) + if not role or str(role.featureInstanceId) != instanceId: raise HTTPException(status_code=404, detail=f"Role {roleId} not found in this instance") # Create the rule @@ -1505,14 +1494,14 @@ async def update_instance_role_rule( rootInterface = getRootInterface() - # Verify role belongs to this instance - roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if not roles or roles[0].get("featureInstanceId") != instanceId: + # Verify role belongs to this instance (Pydantic model) + role = rootInterface.getRole(roleId) + if not role or str(role.featureInstanceId) != instanceId: raise HTTPException(status_code=404, detail=f"Role {roleId} not found in this instance") - # Verify rule belongs to role - existingRules = rootInterface.db.getRecordset(AccessRule, recordFilter={"id": ruleId}) - if not existingRules or existingRules[0].get("roleId") != roleId: + # Verify rule belongs to role (Pydantic model) + existingRule = rootInterface.getAccessRule(ruleId) + if not existingRule or str(existingRule.roleId) != roleId: raise HTTPException(status_code=404, detail=f"Rule {ruleId} not found for this role") # Update only allowed fields @@ -1529,7 +1518,7 @@ async def update_instance_role_rule( updateData["delete"] = ruleData["delete"] if not updateData: - return existingRules[0] + return existingRule.model_dump() try: updated = rootInterface.db.recordModify(AccessRule, ruleId, updateData) @@ -1556,14 +1545,14 @@ async def delete_instance_role_rule( rootInterface = getRootInterface() - # Verify role belongs to this instance - roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if not roles or roles[0].get("featureInstanceId") != instanceId: + # Verify role belongs to this instance (Pydantic model) + role = rootInterface.getRole(roleId) + if not role or str(role.featureInstanceId) != instanceId: raise HTTPException(status_code=404, detail=f"Role {roleId} not found in this instance") - # Verify rule belongs to role - existingRules = rootInterface.db.getRecordset(AccessRule, recordFilter={"id": ruleId}) - if not existingRules or existingRules[0].get("roleId") != roleId: + # Verify rule belongs to role (Pydantic model) + existingRule = rootInterface.getAccessRule(ruleId) + if not existingRule or str(existingRule.roleId) != roleId: raise HTTPException(status_code=404, detail=f"Rule {ruleId} not found for this role") try: diff --git a/modules/interfaces/interfaceAiObjects.py b/modules/interfaces/interfaceAiObjects.py index 5c252ff6..2e6e36f5 100644 --- a/modules/interfaces/interfaceAiObjects.py +++ b/modules/interfaces/interfaceAiObjects.py @@ -97,7 +97,7 @@ class AiObjects: return AiCallResponse( content=errorMsg, modelName="error", - priceUsd=0.0, + priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, @@ -135,7 +135,7 @@ class AiObjects: return AiCallResponse( content=errorMsg, modelName="error", - priceUsd=0.0, + priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, @@ -147,7 +147,7 @@ class AiObjects: return AiCallResponse( content=errorMsg, modelName="error", - priceUsd=0.0, + priceCHF=0.0, processingTime=0.0, bytesSent=inputBytes, bytesReceived=outputBytes, @@ -213,12 +213,12 @@ class AiObjects: outputBytes = len(content.encode("utf-8")) # Calculate price using model's own price calculation method - priceUsd = model.calculatePriceUsd(processingTime, inputBytes, outputBytes) + priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes) return AiCallResponse( content=content, modelName=model.name, - priceUsd=priceUsd, + priceCHF=priceCHF, processingTime=processingTime, bytesSent=inputBytes, bytesReceived=outputBytes, diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py index 3836d674..0b630f85 100644 --- a/modules/interfaces/interfaceBootstrap.py +++ b/modules/interfaces/interfaceBootstrap.py @@ -72,6 +72,10 @@ def initBootstrap(db: DatabaseConnector) -> None: # Seed automation templates (after admin user exists) initAutomationTemplates(db, adminUserId) + + # Initialize feature instances for root mandate + if mandateId: + initRootMandateFeatures(db, mandateId) def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str] = None) -> None: @@ -153,6 +157,67 @@ def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str] logger.info("System bootstrap completed") +def initRootMandateFeatures(db: DatabaseConnector, mandateId: str) -> None: + """ + Create feature instances for root mandate (chatplayground, automation). + These features are available to all users by default. + + Args: + db: Database connector instance + mandateId: Root mandate ID + """ + from modules.datamodels.datamodelFeatures import FeatureInstance + from modules.interfaces.interfaceFeatures import getFeatureInterface + + logger.info("Initializing root mandate features") + + # Features to create instances for + featuresToCreate = [ + {"code": "chatplayground", "label": "Chat Playground"}, + {"code": "automation", "label": "Automation"}, + ] + + featureInterface = getFeatureInterface(db) + + for featureConfig in featuresToCreate: + featureCode = featureConfig["code"] + featureLabel = featureConfig["label"] + + try: + # Check if instance already exists + existingInstances = db.getRecordset( + FeatureInstance, + recordFilter={ + "mandateId": mandateId, + "featureCode": featureCode + } + ) + + if existingInstances: + logger.info(f"Feature instance for '{featureCode}' already exists in root mandate") + continue + + # Create feature instance with template roles copied + instance = featureInterface.createFeatureInstance( + featureCode=featureCode, + mandateId=mandateId, + label=featureLabel, + enabled=True, + copyTemplateRoles=True + ) + + if instance: + instanceId = instance.get("id") if isinstance(instance, dict) else instance.id + logger.info(f"Created feature instance '{instanceId}' for '{featureCode}' in root mandate") + else: + logger.warning(f"Failed to create feature instance for '{featureCode}'") + + except Exception as e: + logger.error(f"Error creating feature instance for '{featureCode}': {e}") + + logger.info("Root mandate features initialization completed") + + def initRootMandate(db: DatabaseConnector) -> Optional[str]: """ Creates the Root mandate if it doesn't exist. diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py index 1f1d1e53..2a872bce 100644 --- a/modules/interfaces/interfaceDbApp.py +++ b/modules/interfaces/interfaceDbApp.py @@ -45,6 +45,7 @@ from modules.datamodels.datamodelMembership import ( ) from modules.datamodels.datamodelFeatures import Feature, FeatureInstance from modules.datamodels.datamodelInvitation import Invitation +from modules.datamodels.datamodelNotification import UserNotification logger = logging.getLogger(__name__) @@ -733,6 +734,9 @@ class AppObjects: # Clear cache to ensure fresh data (already done above) + # Grant access to root mandate features (chatplayground, automation) + self._grantRootMandateFeatureAccess(createdUser[0]["id"]) + return User(**createdUser[0]) except ValueError as e: @@ -796,6 +800,99 @@ class AppObjects: logger.error(f"Error updating user: {str(e)}") raise ValueError(f"Failed to update user: {str(e)}") + def _grantRootMandateFeatureAccess(self, userId: str) -> None: + """ + Grant a new user access to root mandate features (chatplayground, automation). + Creates FeatureAccess with viewer role for each feature instance. + + Args: + userId: User ID to grant access to + """ + try: + from modules.datamodels.datamodelFeatures import FeatureInstance + from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole + from modules.datamodels.datamodelRbac import Role + + # Get root mandate ID (first mandate in system) + allMandates = self.db.getRecordset(Mandate) + if not allMandates: + logger.debug("No mandates found, skipping feature access grant") + return + rootMandateId = allMandates[0].get("id") + + # Feature codes to grant access to + rootFeatureCodes = ["chatplayground", "automation"] + + # Get feature instances for root mandate + allInstances = self.db.getRecordset(FeatureInstance) + featureInstances = [ + inst for inst in allInstances + if inst.get("mandateId") == rootMandateId + and inst.get("featureCode") in rootFeatureCodes + and inst.get("enabled") == True + ] + + if not featureInstances: + logger.debug("No root mandate feature instances found, skipping feature access grant") + return + + # Grant access to each feature instance + for instance in featureInstances: + instanceId = instance.get("id") + featureCode = instance.get("featureCode") + + # Check if user already has access + existingAccess = self.db.getRecordset( + FeatureAccess, + recordFilter={ + "userId": userId, + "featureInstanceId": instanceId + } + ) + + if existingAccess: + logger.debug(f"User {userId} already has access to feature instance {instanceId}") + continue + + # Create FeatureAccess + featureAccess = FeatureAccess( + userId=userId, + featureInstanceId=instanceId, + enabled=True + ) + createdAccess = self.db.recordCreate(FeatureAccess, featureAccess.model_dump()) + + if not createdAccess: + logger.warning(f"Failed to create FeatureAccess for user {userId} to instance {instanceId}") + continue + + featureAccessId = createdAccess.get("id") + + # Get viewer role for this feature instance + allRoles = self.db.getRecordset(Role) + viewerRoles = [ + r for r in allRoles + if r.get("featureInstanceId") == instanceId + and r.get("roleLabel") == "viewer" + ] + + if viewerRoles: + # Create FeatureAccessRole junction + featureAccessRole = FeatureAccessRole( + featureAccessId=featureAccessId, + roleId=viewerRoles[0].get("id") + ) + self.db.recordCreate(FeatureAccessRole, featureAccessRole.model_dump()) + logger.debug(f"Granted viewer role for {featureCode} to user {userId}") + else: + logger.warning(f"No viewer role found for feature instance {instanceId} ({featureCode})") + + logger.info(f"Granted root mandate feature access to user {userId}") + + except Exception as e: + # Log but don't fail user creation + logger.error(f"Error granting root mandate feature access to user {userId}: {e}") + def disableUser(self, userId: str) -> User: """Disables a user if current user has permission.""" return self.updateUser(userId, {"enabled": False}) @@ -1209,6 +1306,31 @@ class AppObjects: logger.error(f"Error getting user connections: {str(e)}") return [] + def getUserConnectionById(self, connectionId: str) -> Optional[UserConnection]: + """Get a single UserConnection by ID.""" + try: + connections = self.db.getRecordset( + UserConnection, recordFilter={"id": connectionId} + ) + if connections: + conn_dict = connections[0] + return UserConnection( + id=conn_dict["id"], + userId=conn_dict["userId"], + authority=conn_dict.get("authority"), + externalId=conn_dict.get("externalId", ""), + externalUsername=conn_dict.get("externalUsername", ""), + externalEmail=conn_dict.get("externalEmail"), + status=conn_dict.get("status", "pending"), + connectedAt=conn_dict.get("connectedAt"), + lastChecked=conn_dict.get("lastChecked"), + expiresAt=conn_dict.get("expiresAt"), + ) + return None + except Exception as e: + logger.error(f"Error getting user connection by ID: {str(e)}") + return None + def addUserConnection( self, userId: str, @@ -1547,6 +1669,106 @@ class AppObjects: logger.error(f"Error deleting UserMandate: {e}") raise ValueError(f"Failed to delete UserMandate: {e}") + def getUserMandatesByMandate(self, mandateId: str) -> List[UserMandate]: + """ + Get all UserMandate records for a specific mandate. + + Args: + mandateId: Mandate ID + + Returns: + List of UserMandate objects + """ + try: + records = self.db.getRecordset( + UserMandate, + recordFilter={"mandateId": mandateId} + ) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(UserMandate(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting UserMandates for mandate {mandateId}: {e}") + return [] + + def getUserMandateRoles(self, userMandateId: str) -> List[UserMandateRole]: + """ + Get all UserMandateRole records for a UserMandate. + + Args: + userMandateId: UserMandate ID + + Returns: + List of UserMandateRole objects + """ + try: + records = self.db.getRecordset( + UserMandateRole, + recordFilter={"userMandateId": userMandateId} + ) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(UserMandateRole(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting UserMandateRoles: {e}") + return [] + + def deleteUserMandateRoles(self, userMandateId: str) -> int: + """ + Delete all role assignments for a UserMandate. + + Args: + userMandateId: UserMandate ID + + Returns: + Number of deleted role assignments + """ + try: + records = self.db.getRecordset( + UserMandateRole, + recordFilter={"userMandateId": userMandateId} + ) + deletedCount = 0 + for record in records: + if self.db.recordDelete(UserMandateRole, record.get("id")): + deletedCount += 1 + return deletedCount + except Exception as e: + logger.error(f"Error deleting UserMandateRoles: {e}") + return 0 + + def validateRoleForMandate(self, roleId: str, mandateId: str) -> Role: + """ + Validate a role exists and belongs to the specified mandate (or is global). + + Args: + roleId: Role ID to validate + mandateId: Mandate ID for context validation + + Returns: + Role object if valid + + Raises: + ValueError: If role not found or belongs to different mandate + """ + role = self.getRole(roleId) + if not role: + raise ValueError(f"Role {roleId} not found") + + # Check mandate scope + if role.mandateId and str(role.mandateId) != str(mandateId): + raise ValueError(f"Role {roleId} belongs to a different mandate") + + # Check feature-instance scope (not allowed at mandate level) + if role.featureInstanceId: + raise ValueError(f"Role {roleId} is a feature-instance role and cannot be assigned at mandate level") + + return role + def getRoleIdsForUserMandate(self, userMandateId: str) -> List[str]: """ Get all role IDs assigned to a UserMandate. @@ -1688,6 +1910,30 @@ class AppObjects: logger.error(f"Error getting FeatureAccesses: {e}") return [] + def getFeatureAccessesByInstance(self, featureInstanceId: str) -> List[FeatureAccess]: + """ + Get all FeatureAccess records for a specific feature instance. + + Args: + featureInstanceId: FeatureInstance ID + + Returns: + List of FeatureAccess objects + """ + try: + records = self.db.getRecordset( + FeatureAccess, + recordFilter={"featureInstanceId": featureInstanceId} + ) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(FeatureAccess(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting FeatureAccesses for instance {featureInstanceId}: {e}") + return [] + def createFeatureAccess(self, userId: str, featureInstanceId: str, roleIds: List[str] = None) -> FeatureAccess: """ Create a FeatureAccess record (grant user access to feature instance). @@ -1750,6 +1996,445 @@ class AppObjects: logger.error(f"Error getting role IDs for FeatureAccess: {e}") return [] + def deleteFeatureAccessRoles(self, featureAccessId: str) -> int: + """ + Delete all FeatureAccessRole records for a FeatureAccess. + + Args: + featureAccessId: FeatureAccess ID + + Returns: + Number of records deleted + """ + try: + records = self.db.getRecordset( + FeatureAccessRole, + recordFilter={"featureAccessId": featureAccessId} + ) + count = 0 + for record in records: + recordId = record.get("id") + if recordId: + self.db.recordDelete(FeatureAccessRole, recordId) + count += 1 + return count + except Exception as e: + logger.error(f"Error deleting FeatureAccessRoles for {featureAccessId}: {e}") + return 0 + + # ============================================ + # Invitation Methods + # ============================================ + + def getInvitation(self, invitationId: str) -> Optional[Invitation]: + """ + Get an invitation by ID. + + Args: + invitationId: Invitation ID + + Returns: + Invitation object if found, None otherwise + """ + try: + records = self.db.getRecordset(Invitation, recordFilter={"id": invitationId}) + if records: + cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")} + return Invitation(**cleanedRecord) + return None + except Exception as e: + logger.error(f"Error getting invitation {invitationId}: {e}") + return None + + def getInvitationByToken(self, token: str) -> Optional[Invitation]: + """ + Get an invitation by token. + + Args: + token: Invitation token + + Returns: + Invitation object if found, None otherwise + """ + try: + records = self.db.getRecordset(Invitation, recordFilter={"token": token}) + if records: + cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")} + return Invitation(**cleanedRecord) + return None + except Exception as e: + logger.error(f"Error getting invitation by token: {e}") + return None + + def getInvitationsByMandate(self, mandateId: str) -> List[Invitation]: + """ + Get all invitations for a mandate. + + Args: + mandateId: Mandate ID + + Returns: + List of Invitation objects + """ + try: + records = self.db.getRecordset(Invitation, recordFilter={"mandateId": mandateId}) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(Invitation(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting invitations for mandate {mandateId}: {e}") + return [] + + def getInvitationsByCreator(self, creatorId: str) -> List[Invitation]: + """ + Get all invitations created by a user. + + Args: + creatorId: User ID who created the invitations + + Returns: + List of Invitation objects + """ + try: + records = self.db.getRecordset(Invitation, recordFilter={"createdBy": creatorId}) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(Invitation(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting invitations by creator {creatorId}: {e}") + return [] + + def getInvitationsByUsedBy(self, usedById: str) -> List[Invitation]: + """ + Get all invitations used by a user. + + Args: + usedById: User ID who used the invitations + + Returns: + List of Invitation objects + """ + try: + records = self.db.getRecordset(Invitation, recordFilter={"usedBy": usedById}) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(Invitation(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting invitations used by {usedById}: {e}") + return [] + + def getInvitationsByTargetUsername(self, targetUsername: str) -> List[Invitation]: + """ + Get all invitations for a target username. + + Args: + targetUsername: Target username for the invitations + + Returns: + List of Invitation objects + """ + try: + records = self.db.getRecordset(Invitation, recordFilter={"targetUsername": targetUsername}) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(Invitation(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting invitations for target username {targetUsername}: {e}") + return [] + + # ============================================ + # Additional Helper Methods + # ============================================ + + def getAllUsers(self) -> List[User]: + """ + Get all users (for SysAdmin only). + + Returns: + List of User objects (without sensitive fields) + """ + try: + records = self.db.getRecordset(UserInDB) + result = [] + for record in records: + # Filter out sensitive and internal fields + cleanedRecord = { + k: v for k, v in record.items() + if not k.startswith("_") and k not in ["hashedPassword", "resetToken", "resetTokenExpires"] + } + # Ensure roleLabels is a list + if cleanedRecord.get("roleLabels") is None: + cleanedRecord["roleLabels"] = [] + result.append(User(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting all users: {e}") + return [] + + def getUserMandateById(self, userMandateId: str) -> Optional[UserMandate]: + """ + Get a UserMandate by its ID. + + Args: + userMandateId: UserMandate ID + + Returns: + UserMandate object if found, None otherwise + """ + try: + records = self.db.getRecordset(UserMandate, recordFilter={"id": userMandateId}) + if records: + cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")} + return UserMandate(**cleanedRecord) + return None + except Exception as e: + logger.error(f"Error getting UserMandate {userMandateId}: {e}") + return None + + def getUserMandateRolesByRole(self, roleId: str) -> List[UserMandateRole]: + """ + Get all UserMandateRole records for a specific role. + + Args: + roleId: Role ID + + Returns: + List of UserMandateRole objects + """ + try: + records = self.db.getRecordset(UserMandateRole, recordFilter={"roleId": roleId}) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(UserMandateRole(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting UserMandateRoles for role {roleId}: {e}") + return [] + + def getFeatureInstance(self, instanceId: str): + """ + Get a FeatureInstance by ID. + + Args: + instanceId: FeatureInstance ID + + Returns: + FeatureInstance object if found, None otherwise + """ + try: + records = self.db.getRecordset(FeatureInstance, recordFilter={"id": instanceId}) + if records: + cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")} + return FeatureInstance(**cleanedRecord) + return None + except Exception as e: + logger.error(f"Error getting FeatureInstance {instanceId}: {e}") + return None + + def getFeatureByCode(self, featureCode: str) -> Optional[Feature]: + """ + Get a Feature by its code. + + Args: + featureCode: Feature code + + Returns: + Feature object if found, None otherwise + """ + try: + records = self.db.getRecordset(Feature, recordFilter={"code": featureCode}) + if records: + cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")} + return Feature(**cleanedRecord) + return None + except Exception as e: + logger.error(f"Error getting Feature by code {featureCode}: {e}") + return None + + def getFeatureInstancesByMandate(self, mandateId: str, enabledOnly: bool = False) -> List[FeatureInstance]: + """ + Get all FeatureInstances for a mandate. + + Args: + mandateId: Mandate ID + enabledOnly: If True, only return enabled instances + + Returns: + List of FeatureInstance objects + """ + try: + recordFilter = {"mandateId": mandateId} + if enabledOnly: + recordFilter["enabled"] = True + records = self.db.getRecordset(FeatureInstance, recordFilter=recordFilter) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(FeatureInstance(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting FeatureInstances for mandate {mandateId}: {e}") + return [] + + # ============================================ + # Notification Methods + # ============================================ + + def getNotification(self, notificationId: str) -> Optional[UserNotification]: + """ + Get a notification by ID. + + Args: + notificationId: Notification ID + + Returns: + UserNotification object if found, None otherwise + """ + try: + records = self.db.getRecordset(UserNotification, recordFilter={"id": notificationId}) + if records: + cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")} + return UserNotification(**cleanedRecord) + return None + except Exception as e: + logger.error(f"Error getting notification {notificationId}: {e}") + return None + + def getNotificationsByUser( + self, + userId: str, + status: Optional[str] = None, + limit: Optional[int] = None + ) -> List[UserNotification]: + """ + Get notifications for a user. + + Args: + userId: User ID + status: Optional status filter (e.g., 'unread') + limit: Optional limit on number of results + + Returns: + List of UserNotification objects + """ + try: + recordFilter = {"userId": userId} + if status: + recordFilter["status"] = status + records = self.db.getRecordset(UserNotification, recordFilter=recordFilter) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(UserNotification(**cleanedRecord)) + # Sort by createdAt descending + result.sort(key=lambda x: x.createdAt or 0, reverse=True) + if limit: + result = result[:limit] + return result + except Exception as e: + logger.error(f"Error getting notifications for user {userId}: {e}") + return [] + + # ============================================ + # AccessRule Methods + # ============================================ + + def getAccessRule(self, ruleId: str) -> Optional[AccessRule]: + """ + Get an AccessRule by ID. + + Args: + ruleId: AccessRule ID + + Returns: + AccessRule object if found, None otherwise + """ + try: + records = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId}) + if records: + cleanedRecord = {k: v for k, v in records[0].items() if not k.startswith("_")} + return AccessRule(**cleanedRecord) + return None + except Exception as e: + logger.error(f"Error getting AccessRule {ruleId}: {e}") + return None + + def getAccessRulesByRole(self, roleId: str) -> List[AccessRule]: + """ + Get all AccessRules for a role. + + Args: + roleId: Role ID + + Returns: + List of AccessRule objects + """ + try: + records = self.db.getRecordset(AccessRule, recordFilter={"roleId": roleId}) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(AccessRule(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting AccessRules for role {roleId}: {e}") + return [] + + def getRolesByFeatureInstance(self, featureInstanceId: str) -> List[Role]: + """ + Get all roles for a feature instance. + + Args: + featureInstanceId: FeatureInstance ID + + Returns: + List of Role objects + """ + try: + records = self.db.getRecordset(Role, recordFilter={"featureInstanceId": featureInstanceId}) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(Role(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting roles for feature instance {featureInstanceId}: {e}") + return [] + + def getRolesByFeatureCode(self, featureCode: str, featureInstanceId: Optional[str] = None) -> List[Role]: + """ + Get all roles for a feature code, optionally filtered by instance. + + Args: + featureCode: Feature code + featureInstanceId: Optional FeatureInstance ID filter + + Returns: + List of Role objects + """ + try: + recordFilter = {"featureCode": featureCode} + if featureInstanceId: + recordFilter["featureInstanceId"] = featureInstanceId + records = self.db.getRecordset(Role, recordFilter=recordFilter) + result = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} + result.append(Role(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting roles for feature code {featureCode}: {e}") + return [] + # Token methods def saveAccessToken(self, token: Token, replace_existing: bool = True) -> None: @@ -1908,6 +2593,56 @@ class AppObjects: ) return None + def getTokensByConnectionIdAndAuthority( + self, connectionId: str, authority: AuthAuthority + ) -> List[Token]: + """Get tokens for a connection with specific authority.""" + try: + tokens = self.db.getRecordset( + Token, recordFilter={ + "connectionId": connectionId, + "authority": authority.value if hasattr(authority, 'value') else str(authority) + } + ) + result = [] + for token_dict in tokens: + cleanedRecord = {k: v for k, v in token_dict.items() if not k.startswith("_")} + result.append(Token(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting tokens by connection and authority: {str(e)}") + return [] + + def getTokensByUserIdNoConnection( + self, userId: str, authority: AuthAuthority + ) -> List[Token]: + """Get tokens for a user without a connection (access tokens).""" + try: + tokens = self.db.getRecordset( + Token, recordFilter={ + "userId": userId, + "connectionId": None, + "authority": authority.value if hasattr(authority, 'value') else str(authority) + } + ) + result = [] + for token_dict in tokens: + cleanedRecord = {k: v for k, v in token_dict.items() if not k.startswith("_")} + result.append(Token(**cleanedRecord)) + return result + except Exception as e: + logger.error(f"Error getting tokens by user and authority: {str(e)}") + return [] + + def getAllTokens(self, recordFilter: dict = None) -> List[dict]: + """Get all tokens with optional filtering (returns raw dicts).""" + try: + tokens = self.db.getRecordset(Token, recordFilter=recordFilter or {}) + return tokens + except Exception as e: + logger.error(f"Error getting all tokens: {str(e)}") + return [] + def findActiveTokenById( self, tokenId: str, @@ -2340,6 +3075,42 @@ class AppObjects: logger.error(f"Error getting role by label {roleLabel}: {str(e)}") return None + def getRoleByLabelAndScope( + self, + roleLabel: str, + mandateId: Optional[str] = None, + featureInstanceId: Optional[str] = None, + featureCode: Optional[str] = None + ) -> Optional[Role]: + """ + Get a role by label with scope filtering. + + Args: + roleLabel: Role label + mandateId: Mandate ID (use None for global roles) + featureInstanceId: Feature instance ID + featureCode: Feature code + + Returns: + Role object if found, None otherwise + """ + try: + recordFilter = {"roleLabel": roleLabel} + if mandateId is not None: + recordFilter["mandateId"] = mandateId + if featureInstanceId is not None: + recordFilter["featureInstanceId"] = featureInstanceId + if featureCode is not None: + recordFilter["featureCode"] = featureCode + + roles = self.db.getRecordset(Role, recordFilter=recordFilter) + if roles: + return Role(**roles[0]) + return None + except Exception as e: + logger.error(f"Error getting role by label and scope {roleLabel}: {str(e)}") + return None + def getAllRoles(self, pagination: Optional[PaginationParams] = None) -> Union[List[Role], PaginatedResult]: """ Get all roles with optional pagination, sorting, and filtering. diff --git a/modules/routes/routeAdminFeatures.py b/modules/routes/routeAdminFeatures.py index 84d2bfcf..87582b9e 100644 --- a/modules/routes/routeAdminFeatures.py +++ b/modules/routes/routeAdminFeatures.py @@ -204,38 +204,26 @@ async def get_my_feature_instances( def _getUserRolesInInstance(rootInterface, userId: str, instanceId: str) -> List[str]: """Get all role labels for a user in a feature instance.""" try: - from modules.datamodels.datamodelRbac import Role - from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole + # Get FeatureAccess for this user and instance (Pydantic model) + featureAccess = rootInterface.getFeatureAccess(userId, instanceId) - # Get FeatureAccess for this user and instance - featureAccesses = rootInterface.db.getRecordset( - FeatureAccess, - recordFilter={"userId": userId, "featureInstanceId": instanceId} - ) - - if featureAccesses: - featureAccessId = featureAccesses[0].get("id") + if featureAccess: + # Get role IDs via interface method + roleIds = rootInterface.getRoleIdsForFeatureAccess(str(featureAccess.id)) - # Get role IDs via FeatureAccessRole junction table - featureAccessRoles = rootInterface.db.getRecordset( - FeatureAccessRole, - recordFilter={"featureAccessId": featureAccessId} - ) - - if featureAccessRoles: - # Get ALL roles, not just the first one + if roleIds: + # Get ALL roles and extract labels roleLabels = [] - for far in featureAccessRoles: - roleId = far.get("roleId") - roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roles: - roleLabels.append(roles[0].get("roleLabel", "user")) + for roleId in roleIds: + role = rootInterface.getRole(roleId) + if role: + roleLabels.append(role.roleLabel) return roleLabels if roleLabels else ["user"] - return ["user"] # Default + return ["user"] # Default - no access means basic user level except Exception as e: logger.debug(f"Error getting user roles: {e}") - return ["user"] + return ["user"] # Fail-safe: default to basic user def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict[str, Any]: @@ -249,66 +237,53 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict } try: - from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role - from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole + from modules.datamodels.datamodelRbac import AccessRuleContext - # Get FeatureAccess for this user and instance - featureAccesses = rootInterface.db.getRecordset( - FeatureAccess, - recordFilter={"userId": userId, "featureInstanceId": instanceId} - ) + # Get FeatureAccess for this user and instance (Pydantic model) + featureAccess = rootInterface.getFeatureAccess(userId, instanceId) - logger.debug(f"_getInstancePermissions: userId={userId}, instanceId={instanceId}, featureAccesses={len(featureAccesses) if featureAccesses else 0}") + logger.debug(f"_getInstancePermissions: userId={userId}, instanceId={instanceId}, featureAccess={featureAccess is not None}") - if not featureAccesses: + if not featureAccess: logger.debug(f"_getInstancePermissions: No FeatureAccess found for user {userId} and instance {instanceId}") return permissions - # Get role IDs via FeatureAccessRole junction table - featureAccessId = featureAccesses[0].get("id") - featureAccessRoles = rootInterface.db.getRecordset( - FeatureAccessRole, - recordFilter={"featureAccessId": featureAccessId} - ) - roleIds = [far.get("roleId") for far in featureAccessRoles] + # Get role IDs via interface method + roleIds = rootInterface.getRoleIdsForFeatureAccess(str(featureAccess.id)) - logger.debug(f"_getInstancePermissions: featureAccessId={featureAccessId}, roleIds={roleIds}") + logger.debug(f"_getInstancePermissions: featureAccessId={featureAccess.id}, roleIds={roleIds}") if not roleIds: - logger.debug(f"_getInstancePermissions: No roles found for FeatureAccess {featureAccessId}") + logger.debug(f"_getInstancePermissions: No roles found for FeatureAccess {featureAccess.id}") return permissions # Check if user has admin role for roleId in roleIds: - roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roles: - roleLabel = roles[0].get("roleLabel", "").lower() - if "admin" in roleLabel: - permissions["isAdmin"] = True - break + role = rootInterface.getRole(roleId) + if role and "admin" in role.roleLabel.lower(): + permissions["isAdmin"] = True + break # Get permissions (AccessRules) for all roles for roleId in roleIds: - accessRules = rootInterface.db.getRecordset( - AccessRule, - recordFilter={"roleId": roleId} - ) + # Get all rules for this role (returns Pydantic models) + accessRules = rootInterface.getAccessRules(roleId=roleId) logger.debug(f"_getInstancePermissions: roleId={roleId}, accessRules={len(accessRules) if accessRules else 0}") for rule in accessRules: - context = rule.get("context", "") - item = rule.get("item", "") + context = rule.context + item = rule.item or "" # Handle DATA context (tables/fields) - if context == "DATA" or context == AccessRuleContext.DATA: + if context == AccessRuleContext.DATA or context == "DATA": if item: # Check if it's a field (table.field) or table if "." in item: tableName, fieldName = item.split(".", 1) if fieldName not in permissions["fields"]: permissions["fields"][fieldName] = {"view": False} - permissions["fields"][fieldName]["view"] = permissions["fields"][fieldName]["view"] or rule.get("view", False) + permissions["fields"][fieldName]["view"] = permissions["fields"][fieldName]["view"] or rule.view else: tableName = item if tableName not in permissions["tables"]: @@ -322,20 +297,18 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict # Merge permissions (highest wins) current = permissions["tables"][tableName] - current["view"] = current["view"] or rule.get("view", False) - current["read"] = _mergeAccessLevel(current["read"], rule.get("read") or "n") - current["create"] = _mergeAccessLevel(current["create"], rule.get("create") or "n") - current["update"] = _mergeAccessLevel(current["update"], rule.get("update") or "n") - current["delete"] = _mergeAccessLevel(current["delete"], rule.get("delete") or "n") + current["view"] = current["view"] or rule.view + current["read"] = _mergeAccessLevel(current["read"], rule.read or "n") + current["create"] = _mergeAccessLevel(current["create"], rule.create or "n") + current["update"] = _mergeAccessLevel(current["update"], rule.update or "n") + current["delete"] = _mergeAccessLevel(current["delete"], rule.delete or "n") # Handle UI context (views) - # Views are stored with full objectKey (e.g., ui.feature.trustee.dashboard) - elif context == "UI" or context == AccessRuleContext.UI: - ruleView = rule.get("view", False) + elif context == AccessRuleContext.UI or context == "UI": if item: # Store with full objectKey as per Navigation-API-Konzept - permissions["views"][item] = permissions["views"].get(item, False) or ruleView - elif ruleView: + permissions["views"][item] = permissions["views"].get(item, False) or rule.view + elif rule.view: # item=None means all views - set a wildcard flag permissions["views"]["_all"] = True @@ -343,7 +316,7 @@ def _getInstancePermissions(rootInterface, userId: str, instanceId: str) -> Dict except Exception as e: logger.debug(f"Error getting instance permissions: {e}") - return permissions + return permissions # Fail-safe: no permissions on error def _mergeAccessLevel(current: str, new: str) -> str: @@ -924,49 +897,35 @@ async def list_feature_instance_users( detail="Access denied to this feature instance" ) - # Get all FeatureAccess records for this instance - from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole - from modules.datamodels.datamodelRbac import Role - - featureAccesses = rootInterface.db.getRecordset( - FeatureAccess, - recordFilter={"featureInstanceId": instanceId} - ) + # Get all FeatureAccess records for this instance (Pydantic models) + featureAccesses = rootInterface.getFeatureAccessesByInstance(instanceId) result = [] for fa in featureAccesses: - userId = fa.get("userId") - featureAccessId = fa.get("id") - - # Get user info - users = rootInterface.db.getRecordset(UserInDB, recordFilter={"id": userId}) - if not users: + # Get user info (Pydantic model) + user = rootInterface.getUser(str(fa.userId)) + if not user: continue - user = users[0] - # Get role IDs via FeatureAccessRole junction table - featureAccessRoles = rootInterface.db.getRecordset( - FeatureAccessRole, - recordFilter={"featureAccessId": featureAccessId} - ) - roleIds = [far.get("roleId") for far in featureAccessRoles] + # Get role IDs via interface method + roleIds = rootInterface.getRoleIdsForFeatureAccess(str(fa.id)) # Get role labels roleLabels = [] for roleId in roleIds: - roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roles: - roleLabels.append(roles[0].get("roleLabel", "")) + role = rootInterface.getRole(roleId) + if role: + roleLabels.append(role.roleLabel) result.append(FeatureInstanceUserResponse( - id=featureAccessId, # FeatureAccess ID as primary key - userId=userId, - username=user.get("username", ""), - email=user.get("email"), - fullName=user.get("fullName"), + id=str(fa.id), # FeatureAccess ID as primary key + userId=str(fa.userId), + username=user.username, + email=user.email, + fullName=user.fullName, roleIds=roleIds, roleLabels=roleLabels, - enabled=fa.get("enabled", True) + enabled=fa.enabled )) return result @@ -1026,8 +985,8 @@ async def add_user_to_feature_instance( ) # Verify user exists - users = rootInterface.db.getRecordset(UserInDB, recordFilter={"id": data.userId}) - if not users: + user = rootInterface.getUser(data.userId) + if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"User '{data.userId}' not found" @@ -1035,10 +994,7 @@ async def add_user_to_feature_instance( # Check if user already has access from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole - existingAccess = rootInterface.db.getRecordset( - FeatureAccess, - recordFilter={"userId": data.userId, "featureInstanceId": instanceId} - ) + existingAccess = rootInterface.getFeatureAccess(data.userId, instanceId) if existingAccess: raise HTTPException( status_code=status.HTTP_409_CONFLICT, @@ -1131,17 +1087,14 @@ async def remove_user_from_feature_instance( # Find FeatureAccess record from modules.datamodels.datamodelMembership import FeatureAccess - existingAccess = rootInterface.db.getRecordset( - FeatureAccess, - recordFilter={"userId": userId, "featureInstanceId": instanceId} - ) + existingAccess = rootInterface.getFeatureAccess(userId, instanceId) if not existingAccess: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="User does not have access to this feature instance" ) - featureAccessId = existingAccess[0].get("id") + featureAccessId = str(existingAccess.id) # Delete FeatureAccess (CASCADE will delete FeatureAccessRole records) rootInterface.db.recordDelete(FeatureAccess, featureAccessId) @@ -1215,29 +1168,21 @@ async def update_feature_instance_user_roles( # Find FeatureAccess record from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole - existingAccess = rootInterface.db.getRecordset( - FeatureAccess, - recordFilter={"userId": userId, "featureInstanceId": instanceId} - ) + existingAccess = rootInterface.getFeatureAccess(userId, instanceId) if not existingAccess: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="User does not have access to this feature instance" ) - featureAccessId = existingAccess[0].get("id") + featureAccessId = str(existingAccess.id) # Update enabled flag if provided if data.enabled is not None: rootInterface.db.recordModify(FeatureAccess, featureAccessId, {"enabled": data.enabled}) - # Delete existing FeatureAccessRole records - existingRoles = rootInterface.db.getRecordset( - FeatureAccessRole, - recordFilter={"featureAccessId": featureAccessId} - ) - for role in existingRoles: - rootInterface.db.recordDelete(FeatureAccessRole, role.get("id")) + # Delete existing FeatureAccessRole records via interface method + rootInterface.deleteFeatureAccessRoles(featureAccessId) # Create new FeatureAccessRole records for roleId in data.roleIds: @@ -1304,21 +1249,17 @@ async def get_feature_instance_available_roles( detail="Access denied to this feature instance" ) - # Get roles for this instance - from modules.datamodels.datamodelRbac import Role - instanceRoles = rootInterface.db.getRecordset( - Role, - recordFilter={"featureInstanceId": instanceId} - ) + # Get roles for this instance using interface method + instanceRoles = rootInterface.getRolesByFeatureInstance(instanceId) result = [] for role in instanceRoles: result.append({ - "id": role.get("id"), - "roleLabel": role.get("roleLabel"), - "description": role.get("description", {}), - "featureCode": role.get("featureCode"), - "isSystemRole": role.get("isSystemRole", False) + "id": role.id, + "roleLabel": role.roleLabel, + "description": role.description or {}, + "featureCode": role.featureCode, + "isSystemRole": role.isSystemRole }) return result @@ -1394,15 +1335,13 @@ def _hasMandateAdminRole(context: RequestContext) -> bool: # Check if any of the user's roles is an admin role try: rootInterface = getRootInterface() - from modules.datamodels.datamodelRbac import Role for roleId in context.roleIds: - roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - role = roleRecords[0] - roleLabel = role.get("roleLabel", "") + role = rootInterface.getRole(roleId) + if role: + roleLabel = role.roleLabel # Admin role at mandate level (not feature-instance level) - if roleLabel == "admin" and role.get("mandateId") and not role.get("featureInstanceId"): + if roleLabel == "admin" and role.mandateId and not role.featureInstanceId: return True return False diff --git a/modules/routes/routeAdminRbacExport.py b/modules/routes/routeAdminRbacExport.py index 2164cb48..28caf8c8 100644 --- a/modules/routes/routeAdminRbacExport.py +++ b/modules/routes/routeAdminRbacExport.py @@ -85,34 +85,31 @@ async def export_global_rbac( try: rootInterface = getRootInterface() - # Get all global template roles (mandateId is NULL) - allRoles = rootInterface.db.getRecordset(Role) - globalRoles = [r for r in allRoles if r.get("mandateId") is None] + # Get all global template roles (mandateId is NULL) using interface method + allRoles = rootInterface.getAllRoles() + globalRoles = [r for r in allRoles if r.mandateId is None] exportRoles = [] for role in globalRoles: - roleId = role.get("id") + roleId = role.id - # Get access rules for this role - accessRules = rootInterface.db.getRecordset( - AccessRule, - recordFilter={"roleId": roleId} - ) + # Get access rules for this role using interface method + accessRules = rootInterface.getAccessRulesByRole(roleId) exportRoles.append(RoleExport( - roleLabel=role.get("roleLabel"), - description=role.get("description", {}), - featureCode=role.get("featureCode"), - isSystemRole=role.get("isSystemRole", False), + roleLabel=role.roleLabel, + description=role.description or {}, + featureCode=role.featureCode, + isSystemRole=role.isSystemRole, accessRules=[ { - "context": r.get("context"), - "item": r.get("item"), - "view": r.get("view", False), - "read": r.get("read"), - "create": r.get("create"), - "update": r.get("update"), - "delete": r.get("delete") + "context": r.context, + "item": r.item, + "view": r.view if r.view is not None else False, + "read": r.read, + "create": r.create, + "update": r.update, + "delete": r.delete } for r in accessRules ] @@ -191,21 +188,20 @@ async def import_global_rbac( result.rolesSkipped += 1 continue - # Check if role exists (global role with same label and featureCode) - existingRoles = rootInterface.db.getRecordset( - Role, - recordFilter={ - "roleLabel": roleLabel, - "mandateId": None, - "featureCode": featureCode - } - ) + # Check if role exists (global role with same label and featureCode) using interface method + allRoles = rootInterface.getAllRoles() + existingRoles = [ + r for r in allRoles + if r.roleLabel == roleLabel + and r.mandateId is None + and r.featureCode == featureCode + ] if existingRoles: if updateExisting: # Update existing role existingRole = existingRoles[0] - roleId = existingRole.get("id") + roleId = existingRole.id rootInterface.db.recordModify( Role, @@ -315,41 +311,38 @@ async def export_mandate_rbac( try: rootInterface = getRootInterface() - # Get mandate-level roles - allRoles = rootInterface.db.getRecordset(Role) + # Get mandate-level roles using interface method + allRoles = rootInterface.getAllRoles() mandateRoles = [ r for r in allRoles - if str(r.get("mandateId")) == str(context.mandateId) + if str(r.mandateId) == str(context.mandateId) ] # Filter by feature instance if not including them if not includeFeatureInstances: - mandateRoles = [r for r in mandateRoles if not r.get("featureInstanceId")] + mandateRoles = [r for r in mandateRoles if not r.featureInstanceId] exportRoles = [] for role in mandateRoles: - roleId = role.get("id") + roleId = role.id - # Get access rules for this role - accessRules = rootInterface.db.getRecordset( - AccessRule, - recordFilter={"roleId": roleId} - ) + # Get access rules for this role using interface method + accessRules = rootInterface.getAccessRulesByRole(roleId) exportRoles.append(RoleExport( - roleLabel=role.get("roleLabel"), - description=role.get("description", {}), - featureCode=role.get("featureCode"), - isSystemRole=role.get("isSystemRole", False), + roleLabel=role.roleLabel, + description=role.description or {}, + featureCode=role.featureCode, + isSystemRole=role.isSystemRole, accessRules=[ { - "context": r.get("context"), - "item": r.get("item"), - "view": r.get("view", False), - "read": r.get("read"), - "create": r.get("create"), - "update": r.get("update"), - "delete": r.get("delete") + "context": r.context, + "item": r.item, + "view": r.view if r.view is not None else False, + "read": r.read, + "create": r.create, + "update": r.update, + "delete": r.delete } for r in accessRules ] @@ -453,21 +446,20 @@ async def import_mandate_rbac( result.rolesSkipped += 1 continue - # Check if role exists (mandate role with same label) - existingRoles = rootInterface.db.getRecordset( - Role, - recordFilter={ - "roleLabel": roleLabel, - "mandateId": str(context.mandateId), - "featureInstanceId": None # Only mandate-level roles - } - ) + # Check if role exists (mandate role with same label) using interface method + allRoles = rootInterface.getAllRoles() + existingRoles = [ + r for r in allRoles + if r.roleLabel == roleLabel + and str(r.mandateId) == str(context.mandateId) + and r.featureInstanceId is None # Only mandate-level roles + ] if existingRoles: if updateExisting: # Update existing role existingRole = existingRoles[0] - roleId = existingRole.get("id") + roleId = existingRole.id rootInterface.db.recordModify( Role, @@ -556,12 +548,11 @@ def _hasMandateAdminRole(context: RequestContext) -> bool: rootInterface = getRootInterface() for roleId in context.roleIds: - roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - role = roleRecords[0] - roleLabel = role.get("roleLabel", "") + role = rootInterface.getRole(roleId) + if role: + roleLabel = role.roleLabel # Admin role at mandate level - if roleLabel == "admin" and role.get("mandateId") and not role.get("featureInstanceId"): + if roleLabel == "admin" and role.mandateId and not role.featureInstanceId: return True return False @@ -580,10 +571,10 @@ def _updateAccessRules(interface, roleId: str, newRules: List[Dict[str, Any]]) - Number of rules created/updated """ try: - # Delete existing rules for this role - existingRules = interface.db.getRecordset(AccessRule, recordFilter={"roleId": roleId}) + # Delete existing rules for this role using interface method + existingRules = interface.getAccessRulesByRole(roleId) for rule in existingRules: - interface.db.recordDelete(AccessRule, rule.get("id")) + interface.db.recordDelete(AccessRule, rule.id) # Create new rules count = 0 diff --git a/modules/routes/routeAdminRbacRoles.py b/modules/routes/routeAdminRbacRoles.py index ad8a0de5..75e00cd5 100644 --- a/modules/routes/routeAdminRbacRoles.py +++ b/modules/routes/routeAdminRbacRoles.py @@ -36,25 +36,17 @@ def _getUserRoleLabels(interface, userId: str) -> List[str]: """ roleLabels: Set[str] = set() - # Get all UserMandate records for this user - userMandates = interface.db.getRecordset(UserMandate, recordFilter={"userId": userId}) + # Get all UserMandate records for this user (Pydantic models) + userMandates = interface.getUserMandates(userId) for um in userMandates: - userMandateId = um.get("id") - if not userMandateId: - continue - - # Get all UserMandateRole records for this membership - userMandateRoles = interface.db.getRecordset( - UserMandateRole, - recordFilter={"userMandateId": str(userMandateId)} - ) + # Get all UserMandateRole records for this membership (Pydantic models) + userMandateRoles = interface.getUserMandateRoles(str(um.id)) for umr in userMandateRoles: - roleId = umr.get("roleId") - if roleId: + if umr.roleId: # Get role by ID to get roleLabel - role = interface.getRole(str(roleId)) + role = interface.getRole(str(umr.roleId)) if role: roleLabels.add(role.roleLabel) @@ -362,21 +354,13 @@ async def list_users_with_roles( try: interface = getRootInterface() - # Get all users (SysAdmin sees all) - # Use db.getRecordset with UserInDB (the actual database model) - allUsersData = interface.db.getRecordset(UserInDB) - # Convert to User objects, filtering out sensitive fields - users = [] - for u in allUsersData: - cleanedUser = {k: v for k, v in u.items() if not k.startswith("_") and k != "hashedPassword" and k != "resetToken" and k != "resetTokenExpires"} - if cleanedUser.get("roleLabels") is None: - cleanedUser["roleLabels"] = [] - users.append(User(**cleanedUser)) + # Get all users via interface method (Pydantic models) + users = interface.getAllUsers() # Filter by mandate if specified (via UserMandate table) if mandateId: - userMandates = interface.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId}) - mandateUserIds = {str(um["userId"]) for um in userMandates} + userMandates = interface.getUserMandatesByMandate(mandateId) + mandateUserIds = {str(um.userId) for um in userMandates} users = [u for u in users if str(u.id) in mandateUserIds] # Filter by role if specified (via UserMandateRole) @@ -499,21 +483,18 @@ async def update_user_roles( logger.warning(f"Non-standard role label assigned: {roleLabel}") # Get user's first mandate (for role assignment) - userMandates = interface.db.getRecordset(UserMandate, recordFilter={"userId": userId}) + userMandates = interface.getUserMandates(userId) if not userMandates: raise HTTPException( status_code=400, detail=f"User {userId} has no mandate memberships. Add to mandate first." ) - userMandateId = str(userMandates[0].get("id")) + userMandateId = str(userMandates[0].id) - # Get current roles for this mandate - existingRoles = interface.db.getRecordset( - UserMandateRole, - recordFilter={"userMandateId": userMandateId} - ) - existingRoleIds = {str(r.get("roleId")) for r in existingRoles} + # Get current roles for this mandate (Pydantic models) + existingRoles = interface.getUserMandateRoles(userMandateId) + existingRoleIds = {str(r.roleId) for r in existingRoles} # Convert roleLabels to roleIds newRoleIds = set() @@ -524,8 +505,8 @@ async def update_user_roles( # Remove roles that are no longer needed for existingRole in existingRoles: - if str(existingRole.get("roleId")) not in newRoleIds: - interface.db.recordDelete(UserMandateRole, str(existingRole.get("id"))) + if str(existingRole.roleId) not in newRoleIds: + interface.removeRoleFromUserMandate(userMandateId, str(existingRole.roleId)) # Add new roles for roleId in newRoleIds: @@ -596,25 +577,22 @@ async def add_user_role( ) # Get user's first mandate - userMandates = interface.db.getRecordset(UserMandate, recordFilter={"userId": userId}) + userMandates = interface.getUserMandates(userId) if not userMandates: raise HTTPException( status_code=400, detail=f"User {userId} has no mandate memberships. Add to mandate first." ) - userMandateId = str(userMandates[0].get("id")) + userMandateId = str(userMandates[0].id) - # Check if role is already assigned - existingAssignment = interface.db.getRecordset( - UserMandateRole, - recordFilter={"userMandateId": userMandateId, "roleId": str(role.id)} - ) + # Check if role is already assigned - use interface method + existingRoles = interface.getUserMandateRoles(userMandateId) + roleAlreadyAssigned = any(str(r.roleId) == str(role.id) for r in existingRoles) - if not existingAssignment: - # Add the role - newRole = UserMandateRole(userMandateId=userMandateId, roleId=str(role.id)) - interface.db.recordCreate(UserMandateRole, newRole.model_dump()) + if not roleAlreadyAssigned: + # Add the role via interface method + interface.addRoleToUserMandate(userMandateId, str(role.id)) logger.info(f"Added role {roleLabel} to user {userId} by SysAdmin {currentUser.id}") userRoleLabels = _getUserRoleLabels(interface, userId) @@ -678,20 +656,14 @@ async def remove_user_role( ) # Remove role from all user's mandates - userMandates = interface.db.getRecordset(UserMandate, recordFilter={"userId": userId}) + userMandates = interface.getUserMandates(userId) roleRemoved = False for um in userMandates: - userMandateId = str(um.get("id")) + userMandateId = str(um.id) - # Find and delete the role assignment - assignments = interface.db.getRecordset( - UserMandateRole, - recordFilter={"userMandateId": userMandateId, "roleId": str(role.id)} - ) - - for assignment in assignments: - interface.db.recordDelete(UserMandateRole, str(assignment.get("id"))) + # Remove role via interface method + if interface.removeRoleFromUserMandate(userMandateId, str(role.id)): roleRemoved = True if roleRemoved: @@ -751,25 +723,21 @@ async def get_users_with_role( detail=f"Role '{roleLabel}' not found" ) - # Get all UserMandateRole assignments for this role - roleAssignments = interface.db.getRecordset( - UserMandateRole, - recordFilter={"roleId": str(role.id)} - ) + # Get all UserMandateRole assignments for this role (Pydantic models) + roleAssignments = interface.getUserMandateRolesByRole(str(role.id)) # Get unique userMandateIds - userMandateIds = {str(ra.get("userMandateId")) for ra in roleAssignments} + userMandateIds = {str(ra.userMandateId) for ra in roleAssignments} # Get userIds from UserMandate records userIds: Set[str] = set() for userMandateId in userMandateIds: - umRecords = interface.db.getRecordset(UserMandate, recordFilter={"id": userMandateId}) - if umRecords: - um = umRecords[0] + um = interface.getUserMandateById(userMandateId) + if um: # Filter by mandate if specified - if mandateId and str(um.get("mandateId")) != mandateId: + if mandateId and str(um.mandateId) != mandateId: continue - userIds.add(str(um.get("userId"))) + userIds.add(str(um.userId)) # Get users and format response result = [] diff --git a/modules/routes/routeAdminRbacRules.py b/modules/routes/routeAdminRbacRules.py index 916caf38..fc9b315e 100644 --- a/modules/routes/routeAdminRbacRules.py +++ b/modules/routes/routeAdminRbacRules.py @@ -179,17 +179,15 @@ async def get_all_permissions( # For UI/RESOURCE: Load system roles the user has across ALL their mandates # This allows users to access system UI elements without needing a specific mandate header - userMandates = rootInterface.db.getRecordset( - UserMandate, - recordFilter={"userId": str(reqContext.user.id), "enabled": True} - ) + allUserMandates = rootInterface.getUserMandates(str(reqContext.user.id)) + userMandates = [um for um in allUserMandates if um.enabled] logger.debug(f"UI/RESOURCE permissions: Found {len(userMandates)} UserMandates for user {reqContext.user.id}") # Collect all role IDs the user has across all mandates for userMandate in userMandates: - mandateRoleIds = rootInterface.getRoleIdsForUserMandate(userMandate.get("id")) - logger.debug(f"UI/RESOURCE permissions: UserMandate {userMandate.get('id')} (mandate {userMandate.get('mandateId')}) has {len(mandateRoleIds)} roles: {mandateRoleIds}") + mandateRoleIds = rootInterface.getRoleIdsForUserMandate(userMandate.id) + logger.debug(f"UI/RESOURCE permissions: UserMandate {userMandate.id} (mandate {userMandate.mandateId}) has {len(mandateRoleIds)} roles: {mandateRoleIds}") for rid in mandateRoleIds: if rid not in roleIds: roleIds.append(rid) @@ -210,14 +208,11 @@ async def get_all_permissions( allRules[ctx] = [] # Get all rules for user's roles - bypass RBAC filtering for roleId in roleIds: - ruleRecords = rootInterface.db.getRecordset( - AccessRule, - recordFilter={"roleId": str(roleId), "context": ctx.value} - ) - for ruleRecord in ruleRecords: - # Convert dict to AccessRule object - cleanedRule = {k: v for k, v in ruleRecord.items() if not k.startswith("_")} - allRules[ctx].append(AccessRule(**cleanedRule)) + # Use interface method and filter by context + rules = rootInterface.getAccessRulesByRole(str(roleId)) + for rule in rules: + if rule.context == ctx.value: + allRules[ctx].append(rule) # Build result: for each context, collect all unique items and calculate permissions for ctx in contextsToFetch: @@ -405,14 +400,8 @@ async def get_access_rules_by_role( try: interface = getRootInterface() - # Build filter for roleId - recordFilter = {"roleId": roleId} - - # Get rules from database - rules = interface.db.getRecordset(AccessRule, recordFilter=recordFilter) - - # Convert to AccessRule objects - ruleObjects = [AccessRule(**rule) for rule in rules] + # Get rules from database using interface method + ruleObjects = interface.getAccessRulesByRole(roleId) return PaginatedResponse( items=[rule.model_dump() for rule in ruleObjects], @@ -1128,13 +1117,9 @@ async def getCatalogObjects( if mandateId: try: interface = getRootInterface() - # Get all feature instances for this mandate - from modules.datamodels.datamodelFeatures import FeatureInstance - instances = interface.db.getRecordset( - FeatureInstance, - recordFilter={"mandateId": mandateId, "enabled": True} - ) - activeFeatures = set(inst.get("featureCode") for inst in instances) + # Get all feature instances for this mandate using interface method + instances = interface.getFeatureInstancesByMandate(mandateId, enabledOnly=True) + activeFeatures = set(inst.featureCode for inst in instances) # Always include "system" feature activeFeatures.add("system") except Exception as e: diff --git a/modules/routes/routeAdminUserAccessOverview.py b/modules/routes/routeAdminUserAccessOverview.py index f12fe2b6..372e2193 100644 --- a/modules/routes/routeAdminUserAccessOverview.py +++ b/modules/routes/routeAdminUserAccessOverview.py @@ -47,11 +47,15 @@ def _getAccessLevelLabel(level: Optional[str]) -> str: return labels.get(level, "-") -def _getRoleScope(role: Dict[str, Any]) -> str: - """Determine the scope of a role.""" - if role.get("featureInstanceId"): +def _getRoleScope(role) -> str: + """Determine the scope of a role. Accepts Role object or dict.""" + # Support both Pydantic models and dicts + featureInstanceId = getattr(role, 'featureInstanceId', None) or (role.get("featureInstanceId") if isinstance(role, dict) else None) + mandateId = getattr(role, 'mandateId', None) or (role.get("mandateId") if isinstance(role, dict) else None) + + if featureInstanceId: return "instance" - elif role.get("mandateId"): + elif mandateId: return "mandate" else: return "global" @@ -79,18 +83,18 @@ async def listUsersForOverview( try: interface = getRootInterface() - # Get all users - allUsersData = interface.db.getRecordset(UserInDB) + # Get all users using interface method + allUsers = interface.getAllUsers() result = [] - for u in allUsersData: + for u in allUsers: result.append({ - "id": u.get("id"), - "username": u.get("username"), - "email": u.get("email"), - "fullName": u.get("fullName"), - "isSysAdmin": u.get("isSysAdmin", False), - "enabled": u.get("enabled", True), + "id": u.id, + "username": u.username, + "email": u.email, + "fullName": u.fullName, + "isSysAdmin": u.isSysAdmin, + "enabled": u.enabled, }) # Sort by username @@ -172,47 +176,43 @@ async def getUserAccessOverview( allRoles = [] roleIdToInfo = {} # Map roleId to role info for later reference - # Get mandates for this user - mandateFilter = {"userId": userId, "enabled": True} + # Get mandates for this user using interface method + allUserMandates = interface.getUserMandates(userId) + # Filter by enabled and optionally mandateId + userMandates = [um for um in allUserMandates if um.enabled] if mandateId: - mandateFilter["mandateId"] = mandateId - - userMandates = interface.db.getRecordset(UserMandate, recordFilter=mandateFilter) + userMandates = [um for um in userMandates if um.mandateId == mandateId] mandatesInfo = [] for um in userMandates: - umId = um.get("id") - umMandateId = um.get("mandateId") + umId = um.id + umMandateId = um.mandateId # Get mandate name mandate = interface.getMandate(umMandateId) mandateName = mandate.name if mandate else umMandateId - # Get roles for this UserMandate - umRoles = interface.db.getRecordset( - UserMandateRole, - recordFilter={"userMandateId": umId} - ) + # Get roles for this UserMandate using interface method + umRoles = interface.getUserMandateRoles(umId) mandateRoleIds = [] for umr in umRoles: - roleId = umr.get("roleId") + roleId = umr.roleId if roleId: mandateRoleIds.append(roleId) - # Get role details - roleRecords = interface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - role = roleRecords[0] + # Get role details using interface method + role = interface.getRole(roleId) + if role: scope = _getRoleScope(role) roleInfo = { "id": roleId, - "roleLabel": role.get("roleLabel"), - "description": role.get("description", {}), + "roleLabel": role.roleLabel, + "description": role.description or {}, "scope": scope, "scopePriority": _getRoleScopePriority(scope), - "mandateId": role.get("mandateId"), - "featureInstanceId": role.get("featureInstanceId"), + "mandateId": role.mandateId, + "featureInstanceId": role.featureInstanceId, "source": "mandate", "sourceMandateId": umMandateId, "sourceMandateName": mandateName, @@ -220,69 +220,59 @@ async def getUserAccessOverview( allRoles.append(roleInfo) roleIdToInfo[roleId] = roleInfo - # Get feature instances for this mandate - featureInstanceFilter = {"userId": userId, "enabled": True} - featureAccesses = interface.db.getRecordset(FeatureAccess, recordFilter=featureInstanceFilter) + # Get feature instances for this mandate using interface method + allFeatureAccesses = interface.getFeatureAccessesForUser(userId) + featureAccesses = [fa for fa in allFeatureAccesses if fa.enabled] featureInstancesInfo = [] for fa in featureAccesses: - faId = fa.get("id") - faInstanceId = fa.get("featureInstanceId") + faId = fa.id + faInstanceId = fa.featureInstanceId - # Check if instance belongs to this mandate - instance = interface.db.getRecordset(FeatureInstance, recordFilter={"id": faInstanceId}) + # Check if instance belongs to this mandate using interface method + instance = interface.getFeatureInstance(faInstanceId) if not instance: continue - instance = instance[0] - if instance.get("mandateId") != umMandateId: + if instance.mandateId != umMandateId: continue # Filter by featureInstanceId if specified if featureInstanceId and faInstanceId != featureInstanceId: continue - # Get feature info - featureCode = instance.get("featureCode") - featureRecords = interface.db.getRecordset(Feature, recordFilter={"code": featureCode}) - featureLabel = featureRecords[0].get("label", {}) if featureRecords else {} + # Get feature info using interface method + featureCode = instance.featureCode + feature = interface.getFeatureByCode(featureCode) + featureLabel = feature.label if feature else {} - # Get roles for this FeatureAccess - faRoles = interface.db.getRecordset( - FeatureAccessRole, - recordFilter={"featureAccessId": faId} - ) + # Get roles for this FeatureAccess using interface method + instanceRoleIds = interface.getRoleIdsForFeatureAccess(faId) - instanceRoleIds = [] - for far in faRoles: - roleId = far.get("roleId") - if roleId: - instanceRoleIds.append(roleId) - - # Get role details (if not already added) - if roleId not in roleIdToInfo: - roleRecords = interface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - role = roleRecords[0] - scope = _getRoleScope(role) - roleInfo = { - "id": roleId, - "roleLabel": role.get("roleLabel"), - "description": role.get("description", {}), - "scope": scope, - "scopePriority": _getRoleScopePriority(scope), - "mandateId": role.get("mandateId"), - "featureInstanceId": role.get("featureInstanceId"), - "source": "featureInstance", - "sourceInstanceId": faInstanceId, - "sourceInstanceLabel": instance.get("label"), - } - allRoles.append(roleInfo) - roleIdToInfo[roleId] = roleInfo + for roleId in instanceRoleIds: + # Get role details (if not already added) + if roleId not in roleIdToInfo: + role = interface.getRole(roleId) + if role: + scope = _getRoleScope(role) + roleInfo = { + "id": roleId, + "roleLabel": role.roleLabel, + "description": role.description or {}, + "scope": scope, + "scopePriority": _getRoleScopePriority(scope), + "mandateId": role.mandateId, + "featureInstanceId": role.featureInstanceId, + "source": "featureInstance", + "sourceInstanceId": faInstanceId, + "sourceInstanceLabel": instance.label, + } + allRoles.append(roleInfo) + roleIdToInfo[roleId] = roleInfo featureInstancesInfo.append({ "id": faInstanceId, - "label": instance.get("label"), + "label": instance.label, "featureCode": featureCode, "featureLabel": featureLabel, "roleIds": instanceRoleIds, @@ -317,12 +307,12 @@ async def getUserAccessOverview( roleLabel = roleInfo.get("roleLabel", "unknown") roleScope = roleInfo.get("scope", "unknown") - # Get all rules for this role - rules = interface.db.getRecordset(AccessRule, recordFilter={"roleId": roleId}) + # Get all rules for this role using interface method + rules = interface.getAccessRulesByRole(roleId) for rule in rules: - context = rule.get("context") - item = rule.get("item") + context = rule.context + item = rule.item accessEntry = { "item": item or "(all)", @@ -333,20 +323,20 @@ async def getUserAccessOverview( } if context == "UI": - accessEntry["view"] = rule.get("view", False) + accessEntry["view"] = rule.view if rule.view is not None else False if accessEntry["view"]: uiAccess.append(accessEntry) elif context == "DATA": - accessEntry["view"] = rule.get("view", False) - accessEntry["read"] = _getAccessLevelLabel(rule.get("read")) - accessEntry["create"] = _getAccessLevelLabel(rule.get("create")) - accessEntry["update"] = _getAccessLevelLabel(rule.get("update")) - accessEntry["delete"] = _getAccessLevelLabel(rule.get("delete")) + accessEntry["view"] = rule.view if rule.view is not None else False + accessEntry["read"] = _getAccessLevelLabel(rule.read) + accessEntry["create"] = _getAccessLevelLabel(rule.create) + accessEntry["update"] = _getAccessLevelLabel(rule.update) + accessEntry["delete"] = _getAccessLevelLabel(rule.delete) dataAccess.append(accessEntry) elif context == "RESOURCE": - accessEntry["view"] = rule.get("view", False) + accessEntry["view"] = rule.view if rule.view is not None else False if accessEntry["view"]: resourceAccess.append(accessEntry) diff --git a/modules/routes/routeChat.py b/modules/routes/routeChat.py deleted file mode 100644 index 137b4a99..00000000 --- a/modules/routes/routeChat.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) 2025 Patrick Motsch -# All rights reserved. -""" -Chat Playground routes for the backend API. -Implements the endpoints for chat playground workflow management. -""" - -import logging -from typing import Optional, Dict, Any -from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request - -# Import auth modules -from modules.auth import limiter, getRequestContext, RequestContext - -# Import interfaces -from modules.interfaces import interfaceDbChat - -# Import models -from modules.datamodels.datamodelChat import ChatWorkflow, UserInputRequest, WorkflowModeEnum - -# Import workflow control functions -from modules.workflows.automation import chatStart, chatStop - -# Configure logger -logger = logging.getLogger(__name__) - -# Create router for chat playground endpoints -router = APIRouter( - prefix="/api/chat/playground", - tags=["Chat Playground"], - responses={404: {"description": "Not found"}} -) - -def _getServiceChat(context: RequestContext): - return interfaceDbChat.getInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None) - -# Workflow start endpoint -@router.post("/start", response_model=ChatWorkflow) -@limiter.limit("120/minute") -async def start_workflow( - request: Request, - workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"), - workflowMode: WorkflowModeEnum = Query(..., description="Workflow mode: 'Dynamic' or 'Automation' (mandatory)"), - userInput: UserInputRequest = Body(...), - context: RequestContext = Depends(getRequestContext) -) -> ChatWorkflow: - """ - Starts a new workflow or continues an existing one. - Corresponds to State 1 in the state machine documentation. - - Args: - workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution - """ - try: - # Start or continue workflow using playground controller - mandateId = str(context.mandateId) if context.mandateId else None - workflow = await chatStart(context.user, userInput, workflowMode, workflowId, mandateId=mandateId) - - return workflow - - except Exception as e: - logger.error(f"Error in start_workflow: {str(e)}") - raise HTTPException( - status_code=500, - detail=str(e) - ) - -# State 8: Workflow Stopped endpoint -@router.post("/{workflowId}/stop", response_model=ChatWorkflow) -@limiter.limit("120/minute") -async def stop_workflow( - request: Request, - workflowId: str = Path(..., description="ID of the workflow to stop"), - context: RequestContext = Depends(getRequestContext) -) -> ChatWorkflow: - """Stops a running workflow.""" - try: - # Stop workflow using playground controller - mandateId = str(context.mandateId) if context.mandateId else None - workflow = await chatStop(context.user, workflowId, mandateId=mandateId) - - return workflow - - except Exception as e: - logger.error(f"Error in stop_workflow: {str(e)}") - raise HTTPException( - status_code=500, - detail=str(e) - ) - -# Unified Chat Data Endpoint for Polling -@router.get("/{workflowId}/chatData") -@limiter.limit("120/minute") -async def get_workflow_chat_data( - request: Request, - workflowId: str = Path(..., description="ID of the workflow"), - afterTimestamp: Optional[float] = Query(None, description="Unix timestamp to get data after"), - context: RequestContext = Depends(getRequestContext) -) -> Dict[str, Any]: - """ - Get unified chat data (messages, logs, stats) for a workflow with timestamp-based selective data transfer. - Returns all data types in chronological order based on _createdAt timestamp. - """ - try: - # Get service center - interfaceDbChat = _getServiceChat(context) - - # Verify workflow exists - workflow = interfaceDbChat.getWorkflow(workflowId) - if not workflow: - raise HTTPException( - status_code=404, - detail=f"Workflow with ID {workflowId} not found" - ) - - # Get unified chat data using the new method - chatData = interfaceDbChat.getUnifiedChatData(workflowId, afterTimestamp) - - return chatData - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error getting unified chat data: {str(e)}", exc_info=True) - raise HTTPException( - status_code=500, - detail=f"Error getting unified chat data: {str(e)}" - ) diff --git a/modules/routes/routeDataConnections.py b/modules/routes/routeDataConnections.py index 5d84efd9..95bbd014 100644 --- a/modules/routes/routeDataConnections.py +++ b/modules/routes/routeDataConnections.py @@ -43,30 +43,14 @@ def getTokenStatusForConnection(interface, connectionId: str) -> tuple[str, Opti - tokenExpiresAt: UTC timestamp or None """ try: - # Query tokens table for the latest token for this connection - tokens = interface.db.getRecordset( - Token, - recordFilter={"connectionId": connectionId} - ) - - if not tokens: - return "none", None - - # Find the most recent token (highest createdAt timestamp) - latestToken = None - latestCreatedAt = 0 - - for tokenData in tokens: - createdAt = parseTimestamp(tokenData.get("createdAt"), default=0) - if createdAt > latestCreatedAt: - latestCreatedAt = createdAt - latestToken = tokenData + # Query tokens table for the latest token for this connection using interface method + latestToken = interface.getConnectionToken(connectionId) if not latestToken: return "none", None # Check if token is expired - expiresAt = parseTimestamp(latestToken.get("expiresAt")) + expiresAt = parseTimestamp(latestToken.expiresAt) if not expiresAt: return "none", None diff --git a/modules/routes/routeDataMandates.py b/modules/routes/routeDataMandates.py index 23358947..38877a9f 100644 --- a/modules/routes/routeDataMandates.py +++ b/modules/routes/routeDataMandates.py @@ -291,9 +291,9 @@ async def delete_mandate( ) # MULTI-TENANT: Delete all UserMandate entries for this mandate first - userMandates = appInterface.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId}) + userMandates = appInterface.getUserMandatesByMandate(mandateId) for um in userMandates: - appInterface.db.deleteRecord(UserMandate, um["id"]) + appInterface.deleteUserMandate(str(um.userId), mandateId) logger.info(f"Deleted {len(userMandates)} UserMandate entries for mandate {mandateId}") # Delete mandate @@ -377,39 +377,46 @@ async def list_mandate_users( ) # Get all UserMandate entries for this mandate - userMandates = rootInterface.db.getRecordset( - UserMandate, - recordFilter={"mandateId": targetMandateId} - ) + userMandates = rootInterface.getUserMandatesByMandate(targetMandateId) result = [] for um in userMandates: # Get user info - user = rootInterface.getUser(um.get("userId")) + user = rootInterface.getUser(str(um.userId)) if not user: continue # Get roles for this membership - roleIds = rootInterface.getRoleIdsForUserMandate(um.get("id")) + roleIds = rootInterface.getRoleIdsForUserMandate(str(um.id)) - # Resolve role labels for display + # Resolve role labels for display (only mandate-level roles, deduplicated) roleLabels = [] + filteredRoleIds = [] + seenLabels = set() for roleId in roleIds: role = rootInterface.getRole(roleId) if role: - roleLabels.append(role.roleLabel) + # Skip feature-instance roles - they don't belong in mandate membership + if role.featureInstanceId: + continue + filteredRoleIds.append(roleId) + if role.roleLabel not in seenLabels: + roleLabels.append(role.roleLabel) + seenLabels.add(role.roleLabel) else: - roleLabels.append(roleId) # Fallback to ID if not found + # Role not found - fail-safe: skip (no access) + logger.warning(f"Role {roleId} not found, skipping") + continue result.append({ - "id": um.get("id"), # UserMandate ID as primary key + "id": str(um.id), # UserMandate ID as primary key "userId": str(user.id), "username": user.username, "email": user.email, "fullName": user.fullName, - "roleIds": roleIds, + "roleIds": filteredRoleIds, "roleLabels": roleLabels, - "enabled": um.get("enabled", True) + "enabled": um.enabled }) # Apply search, filtering, and sorting if pagination requested @@ -545,18 +552,12 @@ async def add_user_to_mandate( # 6. Validate roles (must exist and belong to this mandate or be global) for roleId in data.roleIds: - roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if not roleRecords: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Role {roleId} not found" - ) - role = roleRecords[0] - roleMandateId = role.get("mandateId") - if roleMandateId and str(roleMandateId) != str(targetMandateId): + try: + rootInterface.validateRoleForMandate(roleId, targetMandateId) + except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Role {roleId} belongs to a different mandate" + detail=str(e) ) # 7. Create UserMandate @@ -718,18 +719,12 @@ async def update_user_roles_in_mandate( # Validate new roles for roleId in roleIds: - roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if not roleRecords: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Role {roleId} not found" - ) - role = roleRecords[0] - roleMandateId = role.get("mandateId") - if roleMandateId and str(roleMandateId) != str(targetMandateId): + try: + rootInterface.validateRoleForMandate(roleId, targetMandateId) + except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Role {roleId} belongs to a different mandate" + detail=str(e) ) # Check if removing admin role would leave mandate without admins @@ -745,12 +740,7 @@ async def update_user_roles_in_mandate( ) # Remove existing role assignments - existingRoles = rootInterface.db.getRecordset( - UserMandateRole, - recordFilter={"userMandateId": str(membership.id)} - ) - for er in existingRoles: - rootInterface.db.recordDelete(UserMandateRole, er.get("id")) + rootInterface.deleteUserMandateRoles(str(membership.id)) # Add new role assignments for roleId in roleIds: @@ -812,19 +802,17 @@ def _hasMandateAdminRole(context: RequestContext, mandateId: str) -> bool: rootInterface = interfaceDbApp.getRootInterface() for roleId in context.roleIds: - roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - role = roleRecords[0] - roleLabel = role.get("roleLabel", "") + role = rootInterface.getRole(roleId) + if role: # Admin role at mandate level (not feature-instance level) - if roleLabel == "admin" and role.get("mandateId") and not role.get("featureInstanceId"): + if role.roleLabel == "admin" and role.mandateId and not role.featureInstanceId: return True return False except Exception as e: logger.error(f"Error checking mandate admin role: {e}") - return False + return False # Fail-safe: no access on error def _isLastMandateAdmin(interface, mandateId: str, excludeUserId: str) -> bool: @@ -832,19 +820,17 @@ def _isLastMandateAdmin(interface, mandateId: str, excludeUserId: str) -> bool: Check if excluding this user would leave the mandate without any admins. """ try: - # Get all UserMandates for this mandate - userMandates = interface.db.getRecordset( - UserMandate, - recordFilter={"mandateId": mandateId, "enabled": True} - ) + # Get all UserMandates for this mandate (Pydantic models) + allMandates = interface.getUserMandatesByMandate(mandateId) + userMandates = [um for um in allMandates if um.enabled] adminCount = 0 for um in userMandates: - if str(um.get("userId")) == str(excludeUserId): + if str(um.userId) == str(excludeUserId): continue # Check if this user has admin role - roleIds = interface.getRoleIdsForUserMandate(um.get("id")) + roleIds = interface.getRoleIdsForUserMandate(str(um.id)) if _hasAdminRoleInList(interface, roleIds, mandateId): adminCount += 1 @@ -852,7 +838,7 @@ def _isLastMandateAdmin(interface, mandateId: str, excludeUserId: str) -> bool: except Exception as e: logger.error(f"Error checking last admin: {e}") - return True # Fail-safe: assume they're the last admin + return True # Fail-safe: assume they're the last admin (prevents deletion) def _hasAdminRoleInList(interface, roleIds: List[str], mandateId: str) -> bool: @@ -860,13 +846,10 @@ def _hasAdminRoleInList(interface, roleIds: List[str], mandateId: str) -> bool: Check if any of the role IDs is an admin role for the mandate. """ for roleId in roleIds: - roleRecords = interface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - role = roleRecords[0] - roleLabel = role.get("roleLabel", "") - roleMandateId = role.get("mandateId") - # Admin role at mandate level - if roleLabel == "admin" and (not roleMandateId or str(roleMandateId) == str(mandateId)): - if not role.get("featureInstanceId"): + role = interface.getRole(roleId) + if role: + # Admin role at mandate level (global or mandate-specific, not feature-instance) + if role.roleLabel == "admin" and not role.featureInstanceId: + if not role.mandateId or str(role.mandateId) == str(mandateId): return True return False diff --git a/modules/routes/routeDataUsers.py b/modules/routes/routeDataUsers.py index f963a33c..5e78d12a 100644 --- a/modules/routes/routeDataUsers.py +++ b/modules/routes/routeDataUsers.py @@ -21,7 +21,8 @@ import modules.interfaces.interfaceDbApp as interfaceDbApp from modules.auth import limiter, getRequestContext, RequestContext # Import the attribute definition and helper functions -from modules.datamodels.datamodelUam import User, UserInDB +from modules.datamodels.datamodelUam import User, UserInDB, AuthAuthority +from modules.interfaces.interfaceDbApp import getRootInterface from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict # Configure logger @@ -251,16 +252,10 @@ async def get_users( ) elif context.isSysAdmin: # SysAdmin without mandateId sees all users - # Get all users directly from database using UserInDB (the actual database model) - allUsers = appInterface.db.getRecordset(UserInDB) - # Convert to cleaned dictionaries first for filtering - cleanedUsers = [] - for u in allUsers: - cleanedUser = {k: v for k, v in u.items() if not k.startswith("_") and k != "hashedPassword" and k != "resetToken" and k != "resetTokenExpires"} - # Ensure roleLabels is always a list - if cleanedUser.get("roleLabels") is None: - cleanedUser["roleLabels"] = [] - cleanedUsers.append(cleanedUser) + # Get all users via interface method (returns Pydantic User models) + allUserModels = appInterface.getAllUsers() + # Convert to dictionaries for filtering/sorting + cleanedUsers = [u.model_dump() for u in allUserModels] # Apply server-side filtering and sorting filteredUsers = _applyFiltersAndSort(cleanedUsers, paginationParams) @@ -331,11 +326,7 @@ async def get_user( # MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin) if context.mandateId and not context.isSysAdmin: - from modules.datamodels.datamodelMembership import UserMandate - userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={ - "userId": userId, - "mandateId": str(context.mandateId) - }) + userMandate = appInterface.getUserMandate(userId, str(context.mandateId)) if not userMandate: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -427,11 +418,7 @@ async def update_user( # MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin) if context.mandateId and not context.isSysAdmin: - from modules.datamodels.datamodelMembership import UserMandate - userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={ - "userId": userId, - "mandateId": str(context.mandateId) - }) + userMandate = appInterface.getUserMandate(userId, str(context.mandateId)) if not userMandate: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -482,11 +469,7 @@ async def reset_user_password( # MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin) if context.mandateId and not context.isSysAdmin: - from modules.datamodels.datamodelMembership import UserMandate - userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={ - "userId": userId, - "mandateId": str(context.mandateId) - }) + userMandate = appInterface.getUserMandate(userId, str(context.mandateId)) if not userMandate: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -664,11 +647,7 @@ async def send_password_link( # MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin) if context.mandateId and not context.isSysAdmin: - from modules.datamodels.datamodelMembership import UserMandate - userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={ - "userId": userId, - "mandateId": str(context.mandateId) - }) + userMandate = appInterface.getUserMandate(userId, str(context.mandateId)) if not userMandate: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -791,11 +770,7 @@ async def delete_user( # MULTI-TENANT: Verify user is in the same mandate (unless SysAdmin) if context.mandateId and not context.isSysAdmin: - from modules.datamodels.datamodelMembership import UserMandate - userMandate = appInterface.db.getRecordset(UserMandate, recordFilter={ - "userId": userId, - "mandateId": str(context.mandateId) - }) + userMandate = appInterface.getUserMandate(userId, str(context.mandateId)) if not userMandate: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -803,10 +778,9 @@ async def delete_user( ) # Delete UserMandate entries for this user first - from modules.datamodels.datamodelMembership import UserMandate - userMandates = appInterface.db.getRecordset(UserMandate, recordFilter={"userId": userId}) + userMandates = appInterface.getUserMandates(userId) for um in userMandates: - appInterface.db.deleteRecord(UserMandate, um["id"]) + appInterface.deleteUserMandate(userId, str(um.mandateId)) success = appInterface.deleteUser(userId) if not success: diff --git a/modules/routes/routeDataWorkflows.py b/modules/routes/routeDataWorkflows.py index 799a9855..80ca5986 100644 --- a/modules/routes/routeDataWorkflows.py +++ b/modules/routes/routeDataWorkflows.py @@ -163,16 +163,14 @@ async def update_workflow( # Get workflow interface with current user context workflowInterface = getInterface(currentUser) - # Get raw workflow data from database to check permissions - workflows = workflowInterface.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) - if not workflows: + # Get workflow using interface method to check permissions + workflow = workflowInterface.getWorkflow(workflowId) + if not workflow: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Workflow not found" ) - workflow_data = workflows[0] - # Check if user has permission to update using RBAC if not workflowInterface.checkRbacPermission(ChatWorkflow, "update", workflowId): raise HTTPException( @@ -230,6 +228,49 @@ async def get_workflow_status( detail=f"Error getting workflow status: {str(e)}" ) + +# API Endpoint for stopping a workflow +@router.post("/{workflowId}/stop", response_model=ChatWorkflow) +@limiter.limit("120/minute") +async def stop_workflow( + request: Request, + workflowId: str = Path(..., description="ID of the workflow to stop"), + currentUser: User = Depends(getCurrentUser) +) -> ChatWorkflow: + """ + Stop a running workflow. + This is a general endpoint that can be used by any feature to stop a workflow. + """ + try: + from modules.workflows.automation import chatStop + + # Get the workflow first to get mandateId + interfaceChatDb = getServiceChat(currentUser) + workflow = interfaceChatDb.getWorkflow(workflowId) + + if not workflow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Workflow with ID {workflowId} not found" + ) + + mandateId = workflow.get("mandateId") if isinstance(workflow, dict) else getattr(workflow, "mandateId", None) + + # Stop the workflow + stoppedWorkflow = await chatStop(currentUser, workflowId, mandateId=mandateId) + + return stoppedWorkflow + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error stopping workflow: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error stopping workflow: {str(e)}" + ) + + # API Endpoint for workflow logs with selective data transfer @router.get("/{workflowId}/logs", response_model=PaginatedResponse[ChatLog]) @limiter.limit("120/minute") diff --git a/modules/routes/routeGdpr.py b/modules/routes/routeGdpr.py index 3f06810f..af0c7199 100644 --- a/modules/routes/routeGdpr.py +++ b/modules/routes/routeGdpr.py @@ -109,96 +109,73 @@ async def export_user_data( "authenticationAuthority": str(getattr(currentUser, "authenticationAuthority", "")) } - # Mandate memberships - from modules.datamodels.datamodelMembership import UserMandate - userMandates = rootInterface.db.getRecordset( - UserMandate, - recordFilter={"userId": str(currentUser.id)} - ) + # Mandate memberships using interface method + userMandates = rootInterface.getUserMandates(str(currentUser.id)) mandates = [] for um in userMandates: - mandateId = um.get("mandateId") + mandateId = um.mandateId - # Get mandate details - mandateRecords = rootInterface.db.getRecordset( - Mandate, - recordFilter={"id": mandateId} - ) - mandateName = mandateRecords[0].get("name") if mandateRecords else "Unknown" + # Get mandate details using interface method + mandate = rootInterface.getMandate(mandateId) + mandateName = mandate.name if mandate else "Unknown" # Get roles for this membership - roleIds = rootInterface.getRoleIdsForUserMandate(um.get("id")) + roleIds = rootInterface.getRoleIdsForUserMandate(um.id) mandates.append({ - "userMandateId": um.get("id"), + "userMandateId": um.id, "mandateId": mandateId, "mandateName": mandateName, - "enabled": um.get("enabled", True), + "enabled": um.enabled, "roleIds": roleIds, - "joinedAt": um.get("createdAt") + "joinedAt": um.createdAt }) - # Feature access records - from modules.datamodels.datamodelMembership import FeatureAccess - featureAccesses = rootInterface.db.getRecordset( - FeatureAccess, - recordFilter={"userId": str(currentUser.id)} - ) + # Feature access records using interface method + featureAccesses = rootInterface.getFeatureAccessesForUser(str(currentUser.id)) featureAccessList = [] for fa in featureAccesses: - instanceId = fa.get("featureInstanceId") + instanceId = fa.featureInstanceId - # Get instance details - from modules.datamodels.datamodelFeatures import FeatureInstance - instanceRecords = rootInterface.db.getRecordset( - FeatureInstance, - recordFilter={"id": instanceId} - ) + # Get instance details using interface method + instance = rootInterface.getFeatureInstance(instanceId) - instanceInfo = instanceRecords[0] if instanceRecords else {} - roleIds = rootInterface.getRoleIdsForFeatureAccess(fa.get("id")) + roleIds = rootInterface.getRoleIdsForFeatureAccess(fa.id) featureAccessList.append({ - "featureAccessId": fa.get("id"), + "featureAccessId": fa.id, "featureInstanceId": instanceId, - "featureCode": instanceInfo.get("featureCode"), - "instanceLabel": instanceInfo.get("label"), - "enabled": fa.get("enabled", True), + "featureCode": instance.featureCode if instance else None, + "instanceLabel": instance.label if instance else None, + "enabled": fa.enabled, "roleIds": roleIds }) - # Invitations created by user - from modules.datamodels.datamodelInvitation import Invitation - invitationsCreated = rootInterface.db.getRecordset( - Invitation, - recordFilter={"createdBy": str(currentUser.id)} - ) + # Invitations created by user using interface method + invitationsCreated = rootInterface.getInvitationsByCreator(str(currentUser.id)) invitationsCreatedList = [ { - "id": inv.get("id"), - "mandateId": inv.get("mandateId"), - "createdAt": inv.get("createdAt"), - "expiresAt": inv.get("expiresAt"), - "maxUses": inv.get("maxUses"), - "currentUses": inv.get("currentUses") + "id": inv.id, + "mandateId": inv.mandateId, + "createdAt": inv.createdAt, + "expiresAt": inv.expiresAt, + "maxUses": inv.maxUses, + "currentUses": inv.currentUses } for inv in invitationsCreated ] - # Invitations used by user - invitationsUsed = rootInterface.db.getRecordset( - Invitation, - recordFilter={"usedBy": str(currentUser.id)} - ) + # Invitations used by user using interface method + invitationsUsed = rootInterface.getInvitationsByUsedBy(str(currentUser.id)) invitationsUsedList = [ { - "id": inv.get("id"), - "mandateId": inv.get("mandateId"), - "usedAt": inv.get("usedAt") + "id": inv.id, + "mandateId": inv.mandateId, + "usedAt": inv.usedAt } for inv in invitationsUsed ] @@ -262,26 +239,18 @@ async def export_portable_data( "additionalProperty": [] } - # Add mandate memberships as organization affiliations - from modules.datamodels.datamodelMembership import UserMandate - userMandates = rootInterface.db.getRecordset( - UserMandate, - recordFilter={"userId": str(currentUser.id)} - ) + # Add mandate memberships as organization affiliations using interface method + userMandates = rootInterface.getUserMandates(str(currentUser.id)) affiliations = [] for um in userMandates: - mandateRecords = rootInterface.db.getRecordset( - Mandate, - recordFilter={"id": um.get("mandateId")} - ) - if mandateRecords: - mandate = mandateRecords[0] + mandate = rootInterface.getMandate(um.mandateId) + if mandate: affiliations.append({ "@type": "Organization", - "identifier": um.get("mandateId"), - "name": mandate.get("name"), - "membershipActive": um.get("enabled", True) + "identifier": um.mandateId, + "name": mandate.name, + "membershipActive": um.enabled }) if affiliations: @@ -370,15 +339,12 @@ async def delete_account( # Step 2: Revoke invitations BEFORE generic deletion (business logic) rootInterface = getRootInterface() from modules.datamodels.datamodelInvitation import Invitation - userInvitations = rootInterface.db.getRecordset( - Invitation, - recordFilter={"createdBy": str(currentUser.id)} - ) + userInvitations = rootInterface.getInvitationsByCreator(str(currentUser.id)) for inv in userInvitations: rootInterface.db.recordModify( Invitation, - inv.get("id"), + inv.id, {"revokedAt": getUtcTimestamp()} ) diff --git a/modules/routes/routeInvitations.py b/modules/routes/routeInvitations.py index 2196bd73..6a53fb38 100644 --- a/modules/routes/routeInvitations.py +++ b/modules/routes/routeInvitations.py @@ -131,17 +131,14 @@ async def create_invitation( # Validate role IDs exist and belong to this mandate or are global for roleId in data.roleIds: - from modules.datamodels.datamodelRbac import Role - roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if not roleRecords: + role = rootInterface.getRole(roleId) + if not role: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Role '{roleId}' not found" ) - role = roleRecords[0] # Role must be global or belong to this mandate - roleMandateId = role.get("mandateId") - if roleMandateId and str(roleMandateId) != str(context.mandateId): + if role.mandateId and str(role.mandateId) != str(context.mandateId): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Role '{roleId}' belongs to a different mandate" @@ -149,18 +146,13 @@ async def create_invitation( # Validate feature instance if provided if data.featureInstanceId: - from modules.datamodels.datamodelFeatures import FeatureInstance - instanceRecords = rootInterface.db.getRecordset( - FeatureInstance, - recordFilter={"id": data.featureInstanceId} - ) - if not instanceRecords: + instance = rootInterface.getFeatureInstance(data.featureInstanceId) + if not instance: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Feature instance '{data.featureInstanceId}' not found" ) - instance = instanceRecords[0] - if str(instance.get("mandateId")) != str(context.mandateId): + if str(instance.mandateId) != str(context.mandateId): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Feature instance belongs to a different mandate" @@ -196,14 +188,9 @@ async def create_invitation( if data.email: try: from modules.connectors.connectorMessagingEmail import ConnectorMessagingEmail - from modules.datamodels.datamodelUam import Mandate - # Get mandate name for the email - mandateRecords = rootInterface.db.getRecordset( - Mandate, - recordFilter={"id": str(context.mandateId)} - ) - mandateName = mandateRecords[0].get("name", "PowerOn") if mandateRecords else "PowerOn" + mandate = rootInterface.getMandate(str(context.mandateId)) + mandateName = mandate.name if mandate else "PowerOn" emailConnector = ConnectorMessagingEmail() emailSubject = f"Einladung zu {mandateName}" @@ -259,14 +246,10 @@ async def create_invitation( existingUser = rootInterface.getUserByUsername(data.targetUsername) if existingUser: from modules.routes.routeNotifications import createInvitationNotification - from modules.datamodels.datamodelUam import Mandate # Get mandate name for notification - mandateRecords = rootInterface.db.getRecordset( - Mandate, - recordFilter={"id": str(context.mandateId)} - ) - mandateName = mandateRecords[0].get("mandateLabel", "PowerOn") if mandateRecords else "PowerOn" + mandate = rootInterface.getMandate(str(context.mandateId)) + mandateName = mandate.mandateLabel if mandate and mandate.mandateLabel else "PowerOn" inviterName = context.user.fullName or context.user.username createInvitationNotification( @@ -348,38 +331,38 @@ async def list_invitations( try: rootInterface = getRootInterface() - # Get all invitations for this mandate - allInvitations = rootInterface.db.getRecordset( - Invitation, - recordFilter={"mandateId": str(context.mandateId)} - ) + # Get all invitations for this mandate (Pydantic models) + allInvitations = rootInterface.getInvitationsByMandate(str(context.mandateId)) currentTime = getUtcTimestamp() result = [] for inv in allInvitations: # Skip revoked invitations - if inv.get("revokedAt"): + if inv.revokedAt: continue # Filter by usage - if not includeUsed and inv.get("currentUses", 0) >= inv.get("maxUses", 1): + currentUses = inv.currentUses or 0 + maxUses = inv.maxUses or 1 + if not includeUsed and currentUses >= maxUses: continue # Filter by expiration - if not includeExpired and inv.get("expiresAt", 0) < currentTime: + expiresAt = inv.expiresAt or 0 + if not includeExpired and expiresAt < currentTime: continue # Build invite URL from modules.shared.configuration import APP_CONFIG frontendUrl = APP_CONFIG.get("APP_FRONTEND_URL", "http://localhost:8080") - inviteUrl = f"{frontendUrl}/invite/{inv.get('token')}" + inviteUrl = f"{frontendUrl}/invite/{inv.token}" result.append({ - **{k: v for k, v in inv.items() if not k.startswith("_")}, + **inv.model_dump(), "inviteUrl": inviteUrl, - "isExpired": inv.get("expiresAt", 0) < currentTime, - "isUsedUp": inv.get("currentUses", 0) >= inv.get("maxUses", 1) + "isExpired": expiresAt < currentTime, + "isUsedUp": currentUses >= maxUses }) return result @@ -425,29 +408,24 @@ async def revoke_invitation( try: rootInterface = getRootInterface() - # Get invitation - invitationRecords = rootInterface.db.getRecordset( - Invitation, - recordFilter={"id": invitationId} - ) + # Get invitation (Pydantic model) + invitation = rootInterface.getInvitation(invitationId) - if not invitationRecords: + if not invitation: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Invitation '{invitationId}' not found" ) - invitation = invitationRecords[0] - # Verify mandate access - if str(invitation.get("mandateId")) != str(context.mandateId): + if str(invitation.mandateId) != str(context.mandateId): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Access denied to this invitation" ) # Already revoked? - if invitation.get("revokedAt"): + if invitation.revokedAt: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invitation is already revoked" @@ -496,13 +474,10 @@ async def validate_invitation( try: rootInterface = getRootInterface() - # Find invitation by token - invitationRecords = rootInterface.db.getRecordset( - Invitation, - recordFilter={"token": token} - ) + # Find invitation by token (Pydantic model) + invitation = rootInterface.getInvitationByToken(token) - if not invitationRecords: + if not invitation: return InvitationValidation( valid=False, reason="Invitation not found", @@ -511,10 +486,8 @@ async def validate_invitation( roleIds=[] ) - invitation = invitationRecords[0] - # Check if revoked - if invitation.get("revokedAt"): + if invitation.revokedAt: return InvitationValidation( valid=False, reason="Invitation has been revoked", @@ -525,7 +498,8 @@ async def validate_invitation( # Check if expired currentTime = getUtcTimestamp() - if invitation.get("expiresAt", 0) < currentTime: + expiresAt = invitation.expiresAt or 0 + if expiresAt < currentTime: return InvitationValidation( valid=False, reason="Invitation has expired", @@ -535,7 +509,9 @@ async def validate_invitation( ) # Check if used up - if invitation.get("currentUses", 0) >= invitation.get("maxUses", 1): + currentUses = invitation.currentUses or 0 + maxUses = invitation.maxUses or 1 + if currentUses >= maxUses: return InvitationValidation( valid=False, reason="Invitation has reached maximum uses", @@ -545,34 +521,29 @@ async def validate_invitation( ) # Get additional info for display - mandateId = invitation.get("mandateId") + mandateId = invitation.mandateId mandateName = None roleLabels = [] - targetUsername = invitation.get("targetUsername") + targetUsername = invitation.targetUsername # Get mandate name - from modules.datamodels.datamodelUam import Mandate - mandateRecords = rootInterface.db.getRecordset( - Mandate, - recordFilter={"id": mandateId} - ) - if mandateRecords: - mandateName = mandateRecords[0].get("name") + mandate = rootInterface.getMandate(str(mandateId)) if mandateId else None + if mandate: + mandateName = mandate.name # Get role names - roleIds = invitation.get("roleIds", []) - from modules.datamodels.datamodelRbac import Role + roleIds = invitation.roleIds or [] for roleId in roleIds: - roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - roleLabels.append(roleRecords[0].get("roleLabel", roleId)) + role = rootInterface.getRole(roleId) + if role: + roleLabels.append(role.roleLabel) return InvitationValidation( valid=True, reason=None, - mandateId=mandateId, + mandateId=str(mandateId) if mandateId else None, mandateName=mandateName, - featureInstanceId=invitation.get("featureInstanceId"), + featureInstanceId=str(invitation.featureInstanceId) if invitation.featureInstanceId else None, roleIds=roleIds, roleLabels=roleLabels, targetUsername=targetUsername @@ -608,42 +579,40 @@ async def accept_invitation( try: rootInterface = getRootInterface() - # Find invitation by token - invitationRecords = rootInterface.db.getRecordset( - Invitation, - recordFilter={"token": token} - ) + # Find invitation by token (Pydantic model) + invitation = rootInterface.getInvitationByToken(token) - if not invitationRecords: + if not invitation: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Invitation not found" ) - invitation = invitationRecords[0] - # Validate invitation - if invitation.get("revokedAt"): + if invitation.revokedAt: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invitation has been revoked" ) currentTime = getUtcTimestamp() - if invitation.get("expiresAt", 0) < currentTime: + expiresAt = invitation.expiresAt or 0 + if expiresAt < currentTime: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invitation has expired" ) - if invitation.get("currentUses", 0) >= invitation.get("maxUses", 1): + currentUses = invitation.currentUses or 0 + maxUses = invitation.maxUses or 1 + if currentUses >= maxUses: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invitation has reached maximum uses" ) # Validate username matches - the invitation is bound to a specific user - targetUsername = invitation.get("targetUsername") + targetUsername = invitation.targetUsername if targetUsername and currentUser.username != targetUsername: logger.warning( f"User {currentUser.username} tried to accept invitation meant for {targetUsername}" @@ -653,9 +622,9 @@ async def accept_invitation( detail=f"Diese Einladung ist für Benutzer '{targetUsername}' bestimmt" ) - mandateId = invitation.get("mandateId") - roleIds = invitation.get("roleIds", []) - featureInstanceId = invitation.get("featureInstanceId") + mandateId = str(invitation.mandateId) if invitation.mandateId else None + roleIds = invitation.roleIds or [] + featureInstanceId = str(invitation.featureInstanceId) if invitation.featureInstanceId else None # Check if user is already a member existingMembership = rootInterface.getUserMandate(str(currentUser.id), mandateId) @@ -744,22 +713,19 @@ def _hasMandateAdminRole(context: RequestContext) -> bool: try: rootInterface = getRootInterface() - from modules.datamodels.datamodelRbac import Role for roleId in context.roleIds: - roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - role = roleRecords[0] - roleLabel = role.get("roleLabel", "") - # Admin role at mandate level - if roleLabel == "admin" and role.get("mandateId") and not role.get("featureInstanceId"): + role = rootInterface.getRole(roleId) + if role: + # Admin role at mandate level (not feature-instance level) + if role.roleLabel == "admin" and role.mandateId and not role.featureInstanceId: return True return False except Exception as e: logger.error(f"Error checking mandate admin role: {e}") - return False + return False # Fail-safe: no access on error def _isInstanceRole(interface, roleId: str, featureInstanceId: str) -> bool: @@ -767,11 +733,9 @@ def _isInstanceRole(interface, roleId: str, featureInstanceId: str) -> bool: Check if a role belongs to a specific feature instance. """ try: - from modules.datamodels.datamodelRbac import Role - roleRecords = interface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - role = roleRecords[0] - return str(role.get("featureInstanceId", "")) == str(featureInstanceId) + role = interface.getRole(roleId) + if role: + return str(role.featureInstanceId or "") == str(featureInstanceId) return False except Exception: - return False + return False # Fail-safe: assume not instance role on error diff --git a/modules/routes/routeMessaging.py b/modules/routes/routeMessaging.py index 753fb16f..419e9ae6 100644 --- a/modules/routes/routeMessaging.py +++ b/modules/routes/routeMessaging.py @@ -421,10 +421,9 @@ def _hasTriggerPermission(context: RequestContext) -> bool: rootInterface = getRootInterface() for roleId in context.roleIds: - roleRecords = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roleRecords: - role = roleRecords[0] - roleLabel = role.get("roleLabel", "") + role = rootInterface.getRole(roleId) + if role: + roleLabel = role.roleLabel # Admin role at mandate level or system admin if roleLabel in ("admin", "sysadmin"): return True diff --git a/modules/routes/routeNotifications.py b/modules/routes/routeNotifications.py index 7c8cf9ad..4fc09ac4 100644 --- a/modules/routes/routeNotifications.py +++ b/modules/routes/routeNotifications.py @@ -137,23 +137,19 @@ async def getNotifications( # Build filter recordFilter = {"userId": str(currentUser.id)} - if status: - recordFilter["status"] = status - if type: - recordFilter["type"] = type - - # Get notifications - notifications = rootInterface.db.getRecordset( - model_class=UserNotification, - recordFilter=recordFilter + # Get notifications (Pydantic models, sorted and limited) + notifications = rootInterface.getNotificationsByUser( + userId=str(currentUser.id), + status=status, + limit=limit ) - # Sort by creation date (newest first) and limit - notifications = sorted(notifications, key=lambda x: x.get("createdAt", 0), reverse=True) - if limit: - notifications = notifications[:limit] + # Apply type filter if needed (not common, so filter post-fetch) + if type: + notifications = [n for n in notifications if n.type == type] - return notifications + # Convert to dicts for response + return [n.model_dump() for n in notifications] except Exception as e: logger.error(f"Error getting notifications: {e}") @@ -176,12 +172,10 @@ async def getUnreadCount( try: rootInterface = getRootInterface() - notifications = rootInterface.db.getRecordset( - model_class=UserNotification, - recordFilter={ - "userId": str(currentUser.id), - "status": NotificationStatus.UNREAD.value - } + # Get unread notifications (Pydantic models) + notifications = rootInterface.getNotificationsByUser( + userId=str(currentUser.id), + status=NotificationStatus.UNREAD.value ) return UnreadCountResponse(count=len(notifications)) @@ -207,22 +201,17 @@ async def markAsRead( try: rootInterface = getRootInterface() - # Get the notification - notifications = rootInterface.db.getRecordset( - model_class=UserNotification, - recordFilter={"id": notificationId} - ) + # Get the notification (Pydantic model) + notification = rootInterface.getNotification(notificationId) - if not notifications: + if not notification: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Notification not found" ) - notification = notifications[0] - # Verify ownership - if notification.get("userId") != currentUser.id: + if str(notification.userId) != str(currentUser.id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to access this notification" @@ -262,13 +251,10 @@ async def markAllAsRead( try: rootInterface = getRootInterface() - # Get all unread notifications - notifications = rootInterface.db.getRecordset( - model_class=UserNotification, - recordFilter={ - "userId": currentUser.id, - "status": NotificationStatus.UNREAD.value - } + # Get all unread notifications (Pydantic models) + notifications = rootInterface.getNotificationsByUser( + userId=str(currentUser.id), + status=NotificationStatus.UNREAD.value ) currentTime = getUtcTimestamp() @@ -277,7 +263,7 @@ async def markAllAsRead( for notification in notifications: rootInterface.db.recordModify( model_class=UserNotification, - recordId=notification.get("id"), + recordId=str(notification.id), record={ "status": NotificationStatus.READ.value, "readAt": currentTime @@ -309,37 +295,32 @@ async def executeAction( try: rootInterface = getRootInterface() - # Get the notification - notifications = rootInterface.db.getRecordset( - model_class=UserNotification, - recordFilter={"id": notificationId} - ) + # Get the notification (Pydantic model) + notification = rootInterface.getNotification(notificationId) - if not notifications: + if not notification: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Notification not found" ) - notification = notifications[0] - # Verify ownership - if notification.get("userId") != currentUser.id: + if str(notification.userId) != str(currentUser.id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to access this notification" ) # Check if already actioned - if notification.get("status") == NotificationStatus.ACTIONED.value: + if notification.status == NotificationStatus.ACTIONED.value: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Notification has already been actioned" ) # Validate action exists - actions = notification.get("actions", []) - validActionIds = [a.get("actionId") if isinstance(a, dict) else a.actionId for a in (actions or [])] + actions = notification.actions or [] + validActionIds = [a.get("actionId") if isinstance(a, dict) else a.actionId for a in actions] if actionRequest.actionId not in validActionIds: raise HTTPException( @@ -407,22 +388,17 @@ async def _handleInvitationAction( detail="No invitation reference found" ) - # Get the invitation - invitations = rootInterface.db.getRecordset( - model_class=Invitation, - recordFilter={"id": invitationId} - ) + # Get the invitation (Pydantic model) + invitation = rootInterface.getInvitation(invitationId) - if not invitations: + if not invitation: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Invitation not found" ) - invitation = invitations[0] - # Verify username matches - if invitation.get("targetUsername") != currentUser.username: + if invitation.targetUsername != currentUser.username: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="This invitation is for a different user" @@ -430,19 +406,22 @@ async def _handleInvitationAction( # Check if invitation is still valid currentTime = getUtcTimestamp() - if invitation.get("expiresAt", 0) < currentTime: + expiresAt = invitation.expiresAt or 0 + if expiresAt < currentTime: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invitation has expired" ) - if invitation.get("revokedAt"): + if invitation.revokedAt: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invitation has been revoked" ) - if invitation.get("currentUses", 0) >= invitation.get("maxUses", 1): + currentUses = invitation.currentUses or 0 + maxUses = invitation.maxUses or 1 + if currentUses >= maxUses: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invitation has reached maximum uses" @@ -450,59 +429,34 @@ async def _handleInvitationAction( if actionId == "accept": # Accept the invitation - assign roles and mandate access - mandateId = invitation.get("mandateId") - roleIds = invitation.get("roleIds", []) + mandateId = str(invitation.mandateId) if invitation.mandateId else None + roleIds = list(invitation.roleIds or []) # Ensure user gets the system "user" role for access to public UI elements (e.g. playground) - userRoles = rootInterface.db.getRecordset( - model_class=Role, - recordFilter={"roleLabel": "user"} - ) - if userRoles: - userRoleId = userRoles[0].get("id") + userRole = rootInterface.getRoleByLabel("user") + if userRole: + userRoleId = str(userRole.id) if userRoleId and userRoleId not in roleIds: roleIds = roleIds + [userRoleId] logger.debug(f"Added system 'user' role {userRoleId} to invitation roles") # Get mandate name for result message - mandates = rootInterface.db.getRecordset( - model_class=Mandate, - recordFilter={"id": mandateId} - ) - mandateName = mandates[0].get("mandateLabel", mandateId) if mandates else mandateId + mandate = rootInterface.getMandate(mandateId) if mandateId else None + mandateName = mandate.mandateLabel if mandate and mandate.mandateLabel else mandateId # Check if user already has this mandate - existingMemberships = rootInterface.db.getRecordset( - model_class=UserMandate, - recordFilter={ - "userId": currentUser.id, - "mandateId": mandateId - } - ) + existingMembership = rootInterface.getUserMandate(str(currentUser.id), mandateId) if mandateId else None - if existingMemberships: - # Update existing membership with new roles - existingMembership = existingMemberships[0] - existingRoles = existingMembership.get("roleIds", []) - mergedRoles = list(set(existingRoles + roleIds)) - - rootInterface.db.recordModify( - model_class=UserMandate, - recordId=existingMembership.get("id"), - record={"roleIds": mergedRoles} - ) - logger.info(f"Updated UserMandate for user {currentUser.id} in mandate {mandateId}") + if existingMembership: + # Update existing membership with new roles via interface + # Note: roleIds on UserMandate is deprecated - roles should be assigned via UserMandateRole + logger.info(f"User {currentUser.id} already has membership in mandate {mandateId}, adding roles via UserMandateRole") + # Add roles via junction table + for roleId in roleIds: + rootInterface.addRoleToUserMandate(str(existingMembership.id), roleId) else: - # Create new user-mandate relationship - userMandate = UserMandate( - userId=currentUser.id, - mandateId=mandateId, - roleIds=roleIds - ) - rootInterface.db.recordCreate( - model_class=UserMandate, - record=userMandate.model_dump() - ) + # Create new user-mandate relationship via interface + rootInterface.createUserMandate(str(currentUser.id), mandateId, roleIds) logger.info(f"Created UserMandate for user {currentUser.id} in mandate {mandateId}") # Mark invitation as used @@ -510,9 +464,9 @@ async def _handleInvitationAction( model_class=Invitation, recordId=invitationId, record={ - "usedBy": currentUser.id, + "usedBy": str(currentUser.id), "usedAt": currentTime, - "currentUses": invitation.get("currentUses", 0) + 1 + "currentUses": currentUses + 1 } ) @@ -545,22 +499,17 @@ async def deleteNotification( try: rootInterface = getRootInterface() - # Get the notification - notifications = rootInterface.db.getRecordset( - model_class=UserNotification, - recordFilter={"id": notificationId} - ) + # Get the notification (Pydantic model) + notification = rootInterface.getNotification(notificationId) - if not notifications: + if not notification: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Notification not found" ) - notification = notifications[0] - # Verify ownership - if notification.get("userId") != currentUser.id: + if str(notification.userId) != str(currentUser.id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to delete this notification" diff --git a/modules/routes/routeSecurityAdmin.py b/modules/routes/routeSecurityAdmin.py index 36388c9f..75490eac 100644 --- a/modules/routes/routeSecurityAdmin.py +++ b/modules/routes/routeSecurityAdmin.py @@ -125,8 +125,8 @@ async def list_tokens( if statusFilter: recordFilter["status"] = statusFilter # MULTI-TENANT: SysAdmin sees ALL tokens (no mandate filter) - - tokens = appInterface.db.getRecordset(Token, recordFilter=recordFilter) + # Use interface method to get tokens with flexible filtering + tokens = appInterface.getAllTokens(recordFilter=recordFilter) return tokens except HTTPException: raise @@ -254,15 +254,13 @@ async def revoke_tokens_by_mandate( # MULTI-TENANT: SysAdmin can revoke tokens for any mandate appInterface = getRootInterface() - # Get all UserMandate entries for this mandate to find users - # Note: In new model, users are linked via UserMandate, not User.mandateId - from modules.datamodels.datamodelMembership import UserMandate - userMandates = appInterface.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId}) + # Get all UserMandate entries for this mandate to find users using interface method + userMandates = appInterface.getUserMandatesByMandate(mandateId) total = 0 for um in userMandates: total += appInterface.revokeTokensByUser( - userId=um["userId"], + userId=um.userId, authority=AuthAuthority(authority) if authority else None, mandateId=None, # Revoke all tokens for user revokedBy=currentUser.id, diff --git a/modules/routes/routeSecurityGoogle.py b/modules/routes/routeSecurityGoogle.py index a4795243..d8ef3bef 100644 --- a/modules/routes/routeSecurityGoogle.py +++ b/modules/routes/routeSecurityGoogle.py @@ -15,7 +15,7 @@ import httpx from modules.shared.configuration import APP_CONFIG from modules.interfaces.interfaceDbApp import getInterface, getRootInterface from modules.datamodels.datamodelUam import AuthAuthority, User, ConnectionStatus, UserConnection -from modules.auth import getCurrentUser, limiter +from modules.auth import getCurrentUser, limiter, SECRET_KEY, ALGORITHM from modules.auth import createAccessToken, setAccessTokenCookie, createRefreshToken, setRefreshTokenCookie from modules.auth.tokenManager import TokenManager from modules.shared.timeUtils import createExpirationTimestamp, getUtcTimestamp, parseTimestamp @@ -171,10 +171,9 @@ async def login( try: if connectionId: rootInterface = getRootInterface() - records = rootInterface.db.getRecordset(UserConnection, recordFilter={"id": connectionId}) - if records: - record = records[0] - login_hint = record.get("externalEmail") or record.get("externalUsername") + connection = rootInterface.getUserConnectionById(connectionId) + if connection: + login_hint = connection.externalEmail or connection.externalUsername if login_hint: extra_params["login_hint"] = login_hint if "@" in login_hint: @@ -260,23 +259,20 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo rootInterface = getRootInterface() # Prefer connection flow reuse; fallback to user access token if connection_id: - existing_tokens = rootInterface.db.getRecordset(Token, recordFilter={ - "connectionId": connection_id, - "authority": AuthAuthority.GOOGLE - }) + existing_tokens = rootInterface.getTokensByConnectionIdAndAuthority( + connection_id, AuthAuthority.GOOGLE + ) if existing_tokens: # Use most recent by createdAt - existing_tokens.sort(key=lambda x: parseTimestamp(x.get("createdAt"), default=0), reverse=True) - token_response["refresh_token"] = existing_tokens[0].get("tokenRefresh", "") + existing_tokens.sort(key=lambda x: parseTimestamp(x.createdAt, default=0), reverse=True) + token_response["refresh_token"] = existing_tokens[0].tokenRefresh or "" if not token_response.get("refresh_token") and user_id: - existing_access_tokens = rootInterface.db.getRecordset(Token, recordFilter={ - "userId": user_id, - "connectionId": None, - "authority": AuthAuthority.GOOGLE - }) + existing_access_tokens = rootInterface.getTokensByUserIdNoConnection( + user_id, AuthAuthority.GOOGLE + ) if existing_access_tokens: - existing_access_tokens.sort(key=lambda x: parseTimestamp(x.get("createdAt"), default=0), reverse=True) - token_response["refresh_token"] = existing_access_tokens[0].get("tokenRefresh", "") + existing_access_tokens.sort(key=lambda x: parseTimestamp(x.createdAt, default=0), reverse=True) + token_response["refresh_token"] = existing_access_tokens[0].tokenRefresh or "" except Exception: # Non-fatal; continue without refresh token pass diff --git a/modules/routes/routeSecurityLocal.py b/modules/routes/routeSecurityLocal.py index 8f11a9af..5f132833 100644 --- a/modules/routes/routeSecurityLocal.py +++ b/modules/routes/routeSecurityLocal.py @@ -330,40 +330,34 @@ Falls Sie sich nicht registriert haben, können Sie diese E-Mail ignorieren.""" from modules.datamodels.datamodelUam import Mandate currentTime = getUtcTimestamp() - pendingInvitations = appInterface.db.getRecordset( - model_class=Invitation, - recordFilter={"targetUsername": userData.username} - ) + pendingInvitations = appInterface.getInvitationsByTargetUsername(userData.username) for invitation in pendingInvitations: # Skip expired, revoked, or fully used invitations - if invitation.get("expiresAt", 0) < currentTime: + if (invitation.expiresAt or 0) < currentTime: continue - if invitation.get("revokedAt"): + if invitation.revokedAt: continue - if invitation.get("currentUses", 0) >= invitation.get("maxUses", 1): + if (invitation.currentUses or 0) >= (invitation.maxUses or 1): continue - # Get mandate name for notification - mandateId = invitation.get("mandateId") - mandateRecords = appInterface.db.getRecordset( - Mandate, - recordFilter={"id": mandateId} - ) - mandateName = mandateRecords[0].get("mandateLabel", "PowerOn") if mandateRecords else "PowerOn" + # Get mandate name for notification using interface method + mandateId = invitation.mandateId + mandate = appInterface.getMandate(mandateId) + mandateName = mandate.mandateLabel if mandate else "PowerOn" # Get inviter name - inviterId = invitation.get("createdBy") + inviterId = invitation.createdBy inviter = appInterface.getUserById(inviterId) if inviterId else None inviterName = (inviter.fullName or inviter.username) if inviter else "PowerOn" createInvitationNotification( userId=str(user.id), - invitationId=str(invitation.get("id")), + invitationId=str(invitation.id), mandateName=mandateName, inviterName=inviterName ) - logger.info(f"Created notification for new user {userData.username} for invitation {invitation.get('id')}") + logger.info(f"Created notification for new user {userData.username} for invitation {invitation.id}") except Exception as notifErr: logger.warning(f"Failed to create notifications for pending invitations: {notifErr}") diff --git a/modules/routes/routeSecurityMsft.py b/modules/routes/routeSecurityMsft.py index 921ddafe..68bf6fe8 100644 --- a/modules/routes/routeSecurityMsft.py +++ b/modules/routes/routeSecurityMsft.py @@ -16,7 +16,7 @@ from modules.shared.configuration import APP_CONFIG from modules.interfaces.interfaceDbApp import getInterface, getRootInterface from modules.datamodels.datamodelUam import AuthAuthority, User, ConnectionStatus, UserConnection from modules.datamodels.datamodelSecurity import Token -from modules.auth import getCurrentUser, limiter +from modules.auth import getCurrentUser, limiter, SECRET_KEY, ALGORITHM from modules.auth import createAccessToken, setAccessTokenCookie, createRefreshToken, setRefreshTokenCookie from modules.auth.tokenManager import TokenManager from modules.shared.timeUtils import createExpirationTimestamp, getUtcTimestamp, parseTimestamp @@ -97,11 +97,10 @@ async def login( if connectionId: try: rootInterface = getRootInterface() - # Fetch the connection by ID directly - records = rootInterface.db.getRecordset(UserConnection, recordFilter={"id": connectionId}) - if records: - record = records[0] - login_hint = record.get("externalEmail") or record.get("externalUsername") + # Fetch the connection by ID directly using interface method + connection = rootInterface.getUserConnectionById(connectionId) + if connection: + login_hint = connection.externalEmail or connection.externalUsername if login_hint: login_kwargs["login_hint"] = login_hint # Derive domain hint from email/UPN diff --git a/modules/routes/routeSystem.py b/modules/routes/routeSystem.py index 4e2f9f8f..04e14063 100644 --- a/modules/routes/routeSystem.py +++ b/modules/routes/routeSystem.py @@ -38,13 +38,13 @@ def _getUserRoleIds(userId: str) -> List[str]: rootInterface = getRootInterface() roleIds = [] - userMandates = rootInterface.db.getRecordset( - UserMandate, - recordFilter={"userId": userId, "enabled": True} - ) + # Get UserMandates as Pydantic models + userMandates = rootInterface.getUserMandates(userId) for um in userMandates: - mandateRoleIds = rootInterface.getRoleIdsForUserMandate(um.get("id")) + if not um.enabled: + continue + mandateRoleIds = rootInterface.getRoleIdsForUserMandate(str(um.id)) for rid in mandateRoleIds: if rid not in roleIds: roleIds.append(rid) @@ -60,30 +60,24 @@ def _checkUiPermission(roleIds: List[str], objectKey: str) -> bool: rootInterface = getRootInterface() for roleId in roleIds: - # Get UI rules for this role - rules = rootInterface.db.getRecordset( - AccessRule, - recordFilter={"roleId": roleId, "context": "UI"} - ) + # Get UI rules for this role (returns Pydantic AccessRule models) + rules = rootInterface.getAccessRules(roleId=roleId, context=AccessRuleContext.UI) for rule in rules: - ruleItem = rule.get("item") - ruleView = rule.get("view", False) - - if not ruleView: + if not rule.view: continue # Global rule (item=None) grants access to all UI - if ruleItem is None: + if rule.item is None: return True # Exact match - if ruleItem == objectKey: + if rule.item == objectKey: return True # Wildcard match (e.g., ui.system.* matches ui.system.playground) - if ruleItem.endswith(".*"): - prefix = ruleItem[:-2] + if rule.item.endswith(".*"): + prefix = rule.item[:-2] if objectKey.startswith(prefix): return True @@ -108,6 +102,12 @@ def _getFeatureUiObjects(featureCode: str) -> List[Dict[str, Any]]: elif featureCode == "realestate": from modules.features.realestate.mainRealEstate import UI_OBJECTS return UI_OBJECTS + elif featureCode == "chatplayground": + from modules.features.chatplayground.mainChatplayground import UI_OBJECTS + return UI_OBJECTS + elif featureCode == "automation": + from modules.features.automation.mainAutomation import UI_OBJECTS + return UI_OBJECTS else: logger.warning(f"Unknown feature code: {featureCode}") return [] @@ -287,67 +287,50 @@ def _getInstanceViewPermissions( permissions = {"_all": False, "isAdmin": False} try: - from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role + # Get FeatureAccess for this user and instance (Pydantic model) + featureAccess = rootInterface.getFeatureAccess(userId, instanceId) - # Get FeatureAccess for this user and instance - featureAccesses = rootInterface.db.getRecordset( - FeatureAccess, - recordFilter={"userId": userId, "featureInstanceId": instanceId} - ) - - if not featureAccesses: + if not featureAccess: return permissions - # Get role IDs via FeatureAccessRole junction table - featureAccessId = featureAccesses[0].get("id") - featureAccessRoles = rootInterface.db.getRecordset( - FeatureAccessRole, - recordFilter={"featureAccessId": featureAccessId} - ) - roleIds = [far.get("roleId") for far in featureAccessRoles] + # Get role IDs via interface method + roleIds = rootInterface.getRoleIdsForFeatureAccess(str(featureAccess.id)) if not roleIds: return permissions # Check if user has admin role for roleId in roleIds: - roles = rootInterface.db.getRecordset(Role, recordFilter={"id": roleId}) - if roles: - roleLabel = roles[0].get("roleLabel", "").lower() - if "admin" in roleLabel: - permissions["isAdmin"] = True - break + role = rootInterface.getRole(roleId) + if role and "admin" in role.roleLabel.lower(): + permissions["isAdmin"] = True + break - # Get UI permissions from AccessRules - # Permissions are stored with full objectKey (e.g., ui.feature.trustee.dashboard) + # Get UI permissions from AccessRules (Pydantic models) for roleId in roleIds: - accessRules = rootInterface.db.getRecordset( - AccessRule, - recordFilter={"roleId": roleId, "context": "UI"} - ) + accessRules = rootInterface.getAccessRules(roleId=roleId, context=AccessRuleContext.UI) logger.debug(f"_getInstanceViewPermissions: roleId={roleId}, UI rules count={len(accessRules)}") for rule in accessRules: - if not rule.get("view", False): + if not rule.view: continue - item = rule.get("item") - logger.debug(f"_getInstanceViewPermissions: rule item={item}, view={rule.get('view')}") + logger.debug(f"_getInstanceViewPermissions: rule item={rule.item}, view={rule.view}") - if item is None: + if rule.item is None: # item=None means all views permissions["_all"] = True else: # Store full objectKey as per Navigation-API-Konzept - permissions[item] = True + permissions[rule.item] = True logger.debug(f"_getInstanceViewPermissions: final permissions={permissions}") return permissions except Exception as e: logger.debug(f"Error getting instance view permissions: {e}") - return permissions + return permissions # Fail-safe: no permissions on error def _buildStaticBlocks( diff --git a/modules/security/rbac.py b/modules/security/rbac.py index 11d6e11d..c661e795 100644 --- a/modules/security/rbac.py +++ b/modules/security/rbac.py @@ -173,7 +173,7 @@ class RbacClass: try: # Get Root mandate ID (first mandate in system) allMandates = self.dbApp.getRecordset(Mandate) - rootMandateId = allMandates[0].get("id") if allMandates else None + rootMandateId = allMandates[0]["id"] if allMandates else None # Collect mandates to check: # - If mandateId provided: current mandate + Root mandate (if different) @@ -186,21 +186,21 @@ class RbacClass: # Load roles from each mandate for checkMandateId in mandatesToCheck: - userMandates = self.dbApp.getRecordset( + userMandateRecords = self.dbApp.getRecordset( UserMandate, recordFilter={"userId": user.id, "mandateId": checkMandateId, "enabled": True} ) - if userMandates: - userMandateId = userMandates[0].get("id") + if userMandateRecords: + userMandateId = userMandateRecords[0]["id"] # Lade UserMandateRoles (Mandate-level roles) - userMandateRoles = self.dbApp.getRecordset( + userMandateRoleRecords = self.dbApp.getRecordset( UserMandateRole, recordFilter={"userMandateId": userMandateId} ) - foundRoles = [r.get("roleId") for r in userMandateRoles if r.get("roleId")] + foundRoles = [r["roleId"] for r in userMandateRoleRecords if r.get("roleId")] roleIds.update(foundRoles) # Load FeatureAccess + FeatureAccessRole (Instance-level roles) @@ -215,14 +215,14 @@ class RbacClass: ) if featureAccessRecords: - featureAccessId = featureAccessRecords[0].get("id") + featureAccessId = featureAccessRecords[0]["id"] - featureAccessRoles = self.dbApp.getRecordset( + featureAccessRoleRecords = self.dbApp.getRecordset( FeatureAccessRole, recordFilter={"featureAccessId": featureAccessId} ) - roleIds.update([r.get("roleId") for r in featureAccessRoles if r.get("roleId")]) + roleIds.update([r["roleId"] for r in featureAccessRoleRecords if r.get("roleId")]) except Exception as e: logger.error(f"Error loading role IDs for user {user.id}: {e}") @@ -377,12 +377,14 @@ class RbacClass: if not roleRecords: continue - role = roleRecords[0] + # Convert to Pydantic model for type-safe access + roleDict = {k: v for k, v in roleRecords[0].items() if not k.startswith("_")} + role = Role(**roleDict) # Bestimme Priorität basierend auf Role-Scope - if role.get("featureInstanceId"): + if role.featureInstanceId: priority = 3 # Instance-specific - elif role.get("mandateId"): + elif role.mandateId: priority = 2 # Mandate-specific else: priority = 1 # Global diff --git a/modules/services/serviceChat/mainServiceChat.py b/modules/services/serviceChat/mainServiceChat.py index 137dcd05..b7910720 100644 --- a/modules/services/serviceChat/mainServiceChat.py +++ b/modules/services/serviceChat/mainServiceChat.py @@ -681,7 +681,7 @@ class ChatService: "workflowId": workflow.id, "process": process, "engine": aiResponse.modelName, - "priceUsd": aiResponse.priceUsd, + "priceCHF": aiResponse.priceCHF, "processingTime": aiResponse.processingTime, "bytesSent": aiResponse.bytesSent, "bytesReceived": aiResponse.bytesReceived, diff --git a/modules/services/serviceExtraction/mainServiceExtraction.py b/modules/services/serviceExtraction/mainServiceExtraction.py index 4081158d..9ee9e739 100644 --- a/modules/services/serviceExtraction/mainServiceExtraction.py +++ b/modules/services/serviceExtraction/mainServiceExtraction.py @@ -39,7 +39,7 @@ class ExtractionService: # Verify required internal model is available (used for pricing in extractContent) modelDisplayName = "Internal Document Extractor" model = modelRegistry.getModel(modelDisplayName) - if model is None or model.calculatePriceUsd is None: + if model is None or model.calculatepriceCHF is None: raise RuntimeError(f"FATAL: Required internal model '{modelDisplayName}' is not available. Check connector registration.") def extractContent( @@ -218,18 +218,18 @@ class ExtractionService: modelDisplayName = "Internal Document Extractor" model = modelRegistry.getModel(modelDisplayName) # Hard fail if model is missing; caller must ensure connectors are registered - if model is None or model.calculatePriceUsd is None: + if model is None or model.calculatepriceCHF is None: if docOperationId: self.services.chat.progressLogFinish(docOperationId, False) raise RuntimeError(f"Pricing model not available: {modelDisplayName}") - priceUsd = model.calculatePriceUsd(processingTime, bytesSent, bytesReceived) + priceCHF = model.calculatepriceCHF(processingTime, bytesSent, bytesReceived) # Create AiCallResponse with real calculation # Use model.name for the response (API identifier), not displayName aiResponse = AiCallResponse( content="", # No content for extraction stats needed modelName=model.name, - priceUsd=priceUsd, + priceCHF=priceCHF, processingTime=processingTime, bytesSent=bytesSent, bytesReceived=bytesReceived, @@ -478,7 +478,7 @@ class ExtractionService: "resultSize": len(response.content), "typeGroup": part.typeGroup, "modelName": response.modelName, - "priceUsd": response.priceUsd + "priceCHF": response.priceCHF } ) @@ -606,7 +606,7 @@ class ExtractionService: "originalIndex": i, # Phase 7: Explicit order index "processingOrder": i, # Phase 7: Processing order "modelName": result.modelName, - "priceUsd": result.priceUsd, + "priceCHF": result.priceCHF, "processingTime": result.processingTime, "bytesSent": result.bytesSent, "bytesReceived": result.bytesReceived @@ -1311,7 +1311,7 @@ class ExtractionService: return AiCallResponse( content=modelResponse.content, modelName=model.name, - priceUsd=0.0, + priceCHF=0.0, processingTime=processingTime, bytesSent=0, bytesReceived=0, @@ -1416,7 +1416,7 @@ class ExtractionService: return AiCallResponse( content=mergedContent, modelName=model.name, - priceUsd=sum(r.priceUsd for r in chunkResults), + priceCHF=sum(r.priceCHF for r in chunkResults), processingTime=sum(r.processingTime for r in chunkResults), bytesSent=sum(r.bytesSent for r in chunkResults), bytesReceived=sum(r.bytesReceived for r in chunkResults), @@ -1465,7 +1465,7 @@ class ExtractionService: return AiCallResponse( content=mergedContent, modelName=model.name, - priceUsd=sum(r.priceUsd for r in chunkResults), + priceCHF=sum(r.priceCHF for r in chunkResults), processingTime=sum(r.processingTime for r in chunkResults), bytesSent=sum(r.bytesSent for r in chunkResults), bytesReceived=sum(r.bytesReceived for r in chunkResults), @@ -1492,7 +1492,7 @@ class ExtractionService: return AiCallResponse( content=errorMsg, modelName="error", - priceUsd=0.0, + priceCHF=0.0, processingTime=0.0, bytesSent=inputBytes, bytesReceived=outputBytes, @@ -1622,7 +1622,7 @@ class ExtractionService: return AiCallResponse( content=mergedContent, modelName="multiple", - priceUsd=sum(r.priceUsd for r in allResults), + priceCHF=sum(r.priceCHF for r in allResults), processingTime=sum(r.processingTime for r in allResults), bytesSent=sum(r.bytesSent for r in allResults), bytesReceived=sum(r.bytesReceived for r in allResults), diff --git a/modules/shared/gdprDeletion.py b/modules/shared/gdprDeletion.py index da8a60cf..034b627a 100644 --- a/modules/shared/gdprDeletion.py +++ b/modules/shared/gdprDeletion.py @@ -576,22 +576,16 @@ def _deleteUserDataFromFeatureDatabases(userId: str, currentUser) -> Dict[str, A rootInterface = getRootInterface() - # Get all feature accesses for this user - featureAccesses = rootInterface.db.getRecordset( - FeatureAccess, - recordFilter={"userId": str(userId)} - ) + # Get all feature accesses for this user using interface method + featureAccesses = rootInterface.getFeatureAccessesForUser(str(userId)) # Collect unique feature codes featureCodes: Set[str] = set() for fa in featureAccesses: - instanceId = fa.get("featureInstanceId") - instanceRecords = rootInterface.db.getRecordset( - FeatureInstance, - recordFilter={"id": instanceId} - ) - if instanceRecords: - featureCode = instanceRecords[0].get("featureCode") + instanceId = fa.featureInstanceId + instance = rootInterface.getFeatureInstance(instanceId) + if instance: + featureCode = instance.featureCode if featureCode: featureCodes.add(featureCode) diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py index 113fa903..9b300d78 100644 --- a/modules/system/mainSystem.py +++ b/modules/system/mainSystem.py @@ -25,11 +25,11 @@ FEATURE_ICON = "mdi-cog" # Block Order (gemäss Navigation-API-Konzept): # - System: 10 # - : 15 (wird in routeSystem.py eingefügt) -# - Workflows: 20 # - Basisdaten: 30 -# - Migrate: 40 # - Administration: 200 # +# NOTE: Workflows and Migrate sections removed - now handled as features +# # Item Order: Default-Abstand 10 pro Item # uiComponent: Abgeleitet von objectKey (ui.system.home -> page.system.home) # icon: Wird intern gehalten aber NICHT in der API Response zurückgegeben @@ -60,49 +60,6 @@ NAVIGATION_SECTIONS = [ }, ], }, - { - "id": "workflows", - "title": {"en": "WORKFLOWS", "de": "WORKFLOWS", "fr": "WORKFLOWS"}, - "order": 20, - "items": [ - { - "id": "playground", - "objectKey": "ui.system.playground", - "label": {"en": "Chat Playground", "de": "Chat Playground", "fr": "Chat Playground"}, - "icon": "FaPlay", - "path": "/workflows/playground", - "order": 10, - "public": True, - }, - { - "id": "chats", - "objectKey": "ui.system.chats", - "label": {"en": "Chats", "de": "Chats", "fr": "Chats"}, - "icon": "FaListAlt", - "path": "/workflows/list", - "order": 20, - "public": True, - }, - { - "id": "automations", - "objectKey": "ui.system.automations", - "label": {"en": "Automations", "de": "Automatisierungen", "fr": "Automatisations"}, - "icon": "FaCogs", - "path": "/workflows/automations", - "order": 30, - "public": True, - }, - { - "id": "automation-templates", - "objectKey": "ui.system.automation-templates", - "label": {"en": "Templates", "de": "Vorlagen", "fr": "Modèles"}, - "icon": "FaFileAlt", - "path": "/workflows/automation-templates", - "order": 35, - "public": True, - }, - ], - }, { "id": "basedata", "title": {"en": "BASE DATA", "de": "BASISDATEN", "fr": "DONNÉES DE BASE"}, @@ -134,54 +91,55 @@ NAVIGATION_SECTIONS = [ }, ], }, - { - "id": "migrate", - "title": {"en": "MIGRATE TO FEATURES", "de": "MIGRATE TO FEATURES", "fr": "MIGRER VERS FEATURES"}, - "order": 40, - "deprecated": True, - "items": [ - { - "id": "chatbot", - "objectKey": "ui.system.chatbot", - "label": {"en": "Chatbot", "de": "Chatbot", "fr": "Chatbot"}, - "icon": "FaComments", - "path": "/chatbot", - "order": 10, - "deprecated": True, - }, - { - "id": "pek", - "objectKey": "ui.system.pek", - "label": {"en": "PEK", "de": "PEK", "fr": "PEK"}, - "icon": "FaChartBar", - "path": "/pek", - "order": 20, - "deprecated": True, - }, - { - "id": "speech", - "objectKey": "ui.system.speech", - "label": {"en": "Speech", "de": "Sprache", "fr": "Parole"}, - "icon": "FaMicrophone", - "path": "/speech", - "order": 30, - "deprecated": True, - }, - ], - }, { "id": "admin", "title": {"en": "ADMINISTRATION", "de": "ADMINISTRATION", "fr": "ADMINISTRATION"}, "order": 200, "adminOnly": True, "items": [ + { + "id": "admin-users", + "objectKey": "ui.admin.users", + "label": {"en": "Users", "de": "Benutzer", "fr": "Utilisateurs"}, + "icon": "FaUsers", + "path": "/admin/users", + "order": 10, + "adminOnly": True, + }, + { + "id": "admin-invitations", + "objectKey": "ui.admin.invitations", + "label": {"en": "User Invitations", "de": "Benutzer-Einladungen", "fr": "Invitations utilisateurs"}, + "icon": "FaEnvelopeOpenText", + "path": "/admin/invitations", + "order": 12, + "adminOnly": True, + }, + { + "id": "admin-user-access-overview", + "objectKey": "ui.admin.userAccessOverview", + "label": {"en": "User Access Overview", "de": "Benutzer-Zugriffsübersicht", "fr": "Aperçu des accès utilisateur"}, + "icon": "FaClipboardList", + "path": "/admin/user-access-overview", + "order": 14, + "adminOnly": True, + }, { "id": "admin-mandates", "objectKey": "ui.admin.mandates", "label": {"en": "Mandates", "de": "Mandanten", "fr": "Mandats"}, "icon": "FaBuilding", "path": "/admin/mandates", - "order": 3, + "order": 20, + "adminOnly": True, + }, + { + "id": "admin-user-mandates", + "objectKey": "ui.admin.userMandates", + "label": {"en": "Mandate Members", "de": "Mandanten-Mitglieder", "fr": "Membres du mandat"}, + "icon": "FaUserFriends", + "path": "/admin/user-mandates", + "order": 25, "adminOnly": True, }, { @@ -190,27 +148,36 @@ NAVIGATION_SECTIONS = [ "label": {"en": "Access Management", "de": "Zugriffsverwaltung", "fr": "Gestion des accès"}, "icon": "FaBuilding", "path": "/admin/access", - "order": 5, - "adminOnly": True, - }, - { - "id": "admin-users", - "objectKey": "ui.admin.users", - "label": {"en": "Users & Invitations", "de": "Benutzer & Einladungen", "fr": "Utilisateurs et invitations"}, - "icon": "FaUsers", - "path": "/admin/users", - "order": 10, + "order": 30, "adminOnly": True, }, { "id": "admin-roles", "objectKey": "ui.admin.roles", - "label": {"en": "Roles & Permissions", "de": "Rollen & Berechtigungen", "fr": "Rôles et permissions"}, - "icon": "FaKey", + "label": {"en": "Roles", "de": "Rollen", "fr": "Rôles"}, + "icon": "FaUserTag", "path": "/admin/mandate-roles", "order": 40, "adminOnly": True, }, + { + "id": "admin-mandate-role-permissions", + "objectKey": "ui.admin.mandateRolePermissions", + "label": {"en": "Role Permissions", "de": "Rollen-Berechtigungen", "fr": "Permissions des rôles"}, + "icon": "FaKey", + "path": "/admin/mandate-role-permissions", + "order": 45, + "adminOnly": True, + }, + { + "id": "admin-feature-roles", + "objectKey": "ui.admin.featureRoles", + "label": {"en": "Feature Roles & Permissions", "de": "Features Rollen & Rechte", "fr": "Rôles et droits des features"}, + "icon": "FaShieldAlt", + "path": "/admin/feature-roles", + "order": 50, + "adminOnly": True, + }, ], }, ] diff --git a/modules/workflows/methods/methodAi/actions/process.py b/modules/workflows/methods/methodAi/actions/process.py index b85761b7..752fe7f6 100644 --- a/modules/workflows/methods/methodAi/actions/process.py +++ b/modules/workflows/methods/methodAi/actions/process.py @@ -153,7 +153,7 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult: metadata=AiResponseMetadata( additionalData={ "modelName": aiResponse_obj.modelName, - "priceUsd": aiResponse_obj.priceUsd, + "priceCHF": aiResponse_obj.priceCHF, "processingTime": aiResponse_obj.processingTime, "bytesSent": aiResponse_obj.bytesSent, "bytesReceived": aiResponse_obj.bytesReceived, diff --git a/tests/functional/test02_ai_models.py b/tests/functional/test02_ai_models.py index 12a374f8..f1b2f62f 100644 --- a/tests/functional/test02_ai_models.py +++ b/tests/functional/test02_ai_models.py @@ -628,7 +628,7 @@ Width: {crawlWidth} "hasContent": True, "error": None, "modelUsed": modelName, - "priceUsd": 0.0, + "priceCHF": 0.0, "bytesSent": 0, "bytesReceived": contentLength, "isValidJson": True, From d118128813366e340d6565cb7db2fc2bdb610f28 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Wed, 4 Feb 2026 21:50:55 +0100 Subject: [PATCH 02/18] billing initial --- app.py | 3 + modules/datamodels/datamodelBilling.py | 265 +++++++ .../automation/datamodelFeatureAutomation.py | 2 + modules/interfaces/interfaceBootstrap.py | 113 +++ modules/interfaces/interfaceDbBilling.py | 734 ++++++++++++++++++ modules/routes/routeBilling.py | 557 +++++++++++++ modules/services/serviceBilling/__init__.py | 7 + .../serviceBilling/mainServiceBilling.py | 408 ++++++++++ .../services/serviceChat/mainServiceChat.py | 62 +- modules/system/mainSystem.py | 32 + 10 files changed, 2182 insertions(+), 1 deletion(-) create mode 100644 modules/datamodels/datamodelBilling.py create mode 100644 modules/interfaces/interfaceDbBilling.py create mode 100644 modules/routes/routeBilling.py create mode 100644 modules/services/serviceBilling/__init__.py create mode 100644 modules/services/serviceBilling/mainServiceBilling.py diff --git a/app.py b/app.py index 474de4d6..609d0c07 100644 --- a/app.py +++ b/app.py @@ -503,6 +503,9 @@ app.include_router(userAccessOverviewRouter) from modules.routes.routeGdpr import router as gdprRouter app.include_router(gdprRouter) +from modules.routes.routeBilling import router as billingRouter +app.include_router(billingRouter) + # ============================================================================ # SYSTEM ROUTES (Navigation, etc.) # ============================================================================ diff --git a/modules/datamodels/datamodelBilling.py b/modules/datamodels/datamodelBilling.py new file mode 100644 index 00000000..e7e59eb4 --- /dev/null +++ b/modules/datamodels/datamodelBilling.py @@ -0,0 +1,265 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Billing models: BillingAccount, BillingTransaction, BillingSettings, UsageStatistics.""" + +from typing import List, Dict, Any, Optional +from enum import Enum +from datetime import date, datetime +from pydantic import BaseModel, Field +from modules.shared.attributeUtils import registerModelLabels +import uuid + + +class BillingModelEnum(str, Enum): + """Billing model types.""" + PREPAY_MANDATE = "PREPAY_MANDATE" # Prepaid budget shared by all users in mandate + PREPAY_USER = "PREPAY_USER" # Prepaid budget per user within mandate + CREDIT_POSTPAY = "CREDIT_POSTPAY" # Credit with monthly invoice (requires billing address) + UNLIMITED = "UNLIMITED" # No cost limitation (internal mandates only) + + +class AccountTypeEnum(str, Enum): + """Account type for billing accounts.""" + MANDATE = "MANDATE" # Account for entire mandate + USER = "USER" # Account for specific user within mandate + + +class TransactionTypeEnum(str, Enum): + """Transaction types for billing.""" + CREDIT = "CREDIT" # Credit/top-up (positive) + DEBIT = "DEBIT" # Debit/usage (positive amount, reduces balance) + ADJUSTMENT = "ADJUSTMENT" # Manual adjustment by admin + + +class ReferenceTypeEnum(str, Enum): + """Reference types for transactions.""" + WORKFLOW = "WORKFLOW" # AI workflow usage + PAYMENT = "PAYMENT" # Payment/top-up + ADMIN = "ADMIN" # Admin adjustment + SYSTEM = "SYSTEM" # System credit (e.g., initial credit) + + +class PeriodTypeEnum(str, Enum): + """Period types for usage statistics.""" + DAY = "DAY" + MONTH = "MONTH" + YEAR = "YEAR" + + +class BillingAddress(BaseModel): + """Billing address for CREDIT_POSTPAY mandates.""" + company: str = Field(..., description="Company name") + street: str = Field(..., description="Street and number") + zip: str = Field(..., description="Postal code") + city: str = Field(..., description="City") + country: str = Field(default="CH", description="Country code") + vatNumber: Optional[str] = Field(None, description="VAT number (optional)") + + +registerModelLabels( + "BillingAddress", + {"en": "Billing Address", "de": "Rechnungsadresse"}, + { + "company": {"en": "Company", "de": "Firma"}, + "street": {"en": "Street", "de": "Strasse"}, + "zip": {"en": "ZIP", "de": "PLZ"}, + "city": {"en": "City", "de": "Ort"}, + "country": {"en": "Country", "de": "Land"}, + "vatNumber": {"en": "VAT Number", "de": "MwSt-Nummer"}, + }, +) + + +class BillingAccount(BaseModel): + """Billing account for mandate or user-mandate combination.""" + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) + mandateId: str = Field(..., description="Foreign key to Mandate") + userId: Optional[str] = Field(None, description="Foreign key to User (only for PREPAY_USER)") + accountType: AccountTypeEnum = Field(..., description="Account type: MANDATE or USER") + balance: float = Field(default=0.0, description="Current balance in CHF") + creditLimit: Optional[float] = Field(None, description="Credit limit in CHF (only for CREDIT_POSTPAY)") + warningThreshold: float = Field(default=0.0, description="Warning threshold in CHF") + lastWarningAt: Optional[datetime] = Field(None, description="Last warning sent timestamp") + enabled: bool = Field(default=True, description="Account is active") + + +registerModelLabels( + "BillingAccount", + {"en": "Billing Account", "de": "Abrechnungskonto"}, + { + "id": {"en": "ID", "de": "ID"}, + "mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"}, + "userId": {"en": "User ID", "de": "Benutzer-ID"}, + "accountType": {"en": "Account Type", "de": "Kontotyp"}, + "balance": {"en": "Balance (CHF)", "de": "Guthaben (CHF)"}, + "creditLimit": {"en": "Credit Limit (CHF)", "de": "Kreditlimit (CHF)"}, + "warningThreshold": {"en": "Warning Threshold (CHF)", "de": "Warnschwelle (CHF)"}, + "lastWarningAt": {"en": "Last Warning", "de": "Letzte Warnung"}, + "enabled": {"en": "Enabled", "de": "Aktiv"}, + }, +) + + +class BillingTransaction(BaseModel): + """Single billing transaction (credit, debit, adjustment).""" + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) + accountId: str = Field(..., description="Foreign key to BillingAccount") + transactionType: TransactionTypeEnum = Field(..., description="Transaction type") + amount: float = Field(..., description="Amount in CHF (always positive)") + description: str = Field(..., description="Transaction description") + + # Reference to source + referenceType: Optional[ReferenceTypeEnum] = Field(None, description="Reference type") + referenceId: Optional[str] = Field(None, description="Reference ID") + + # Context for workflow transactions + workflowId: Optional[str] = Field(None, description="Workflow ID (for WORKFLOW transactions)") + featureInstanceId: Optional[str] = Field(None, description="Feature instance ID") + featureCode: Optional[str] = Field(None, description="Feature code (e.g., chatplayground, automation)") + aicoreProvider: Optional[str] = Field(None, description="AICore provider (anthropic, openai, etc.)") + + +registerModelLabels( + "BillingTransaction", + {"en": "Billing Transaction", "de": "Transaktion"}, + { + "id": {"en": "ID", "de": "ID"}, + "accountId": {"en": "Account ID", "de": "Konto-ID"}, + "transactionType": {"en": "Type", "de": "Typ"}, + "amount": {"en": "Amount (CHF)", "de": "Betrag (CHF)"}, + "description": {"en": "Description", "de": "Beschreibung"}, + "referenceType": {"en": "Reference Type", "de": "Referenztyp"}, + "referenceId": {"en": "Reference ID", "de": "Referenz-ID"}, + "workflowId": {"en": "Workflow ID", "de": "Workflow-ID"}, + "featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID"}, + "featureCode": {"en": "Feature Code", "de": "Feature-Code"}, + "aicoreProvider": {"en": "AI Provider", "de": "AI-Anbieter"}, + }, +) + + +class BillingSettings(BaseModel): + """Billing settings per mandate.""" + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) + mandateId: str = Field(..., description="Foreign key to Mandate (UNIQUE)") + billingModel: BillingModelEnum = Field(..., description="Billing model") + + # Configuration + defaultUserCredit: float = Field(default=10.0, description="Initial credit in CHF for new users (PREPAY_USER)") + warningThresholdPercent: float = Field(default=10.0, description="Warning threshold as percentage") + blockOnZeroBalance: bool = Field(default=True, description="Block AI features when balance is zero") + + # Billing address (required for CREDIT_POSTPAY) + billingAddress: Optional[BillingAddress] = Field(None, description="Billing address") + + # Notifications + notifyEmails: List[str] = Field(default_factory=list, description="Email addresses for billing notifications") + notifyOnWarning: bool = Field(default=True, description="Send email when warning threshold is reached") + + +registerModelLabels( + "BillingSettings", + {"en": "Billing Settings", "de": "Abrechnungseinstellungen"}, + { + "id": {"en": "ID", "de": "ID"}, + "mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"}, + "billingModel": {"en": "Billing Model", "de": "Abrechnungsmodell"}, + "defaultUserCredit": {"en": "Default User Credit (CHF)", "de": "Standard-Startguthaben (CHF)"}, + "warningThresholdPercent": {"en": "Warning Threshold (%)", "de": "Warnschwelle (%)"}, + "blockOnZeroBalance": {"en": "Block on Zero Balance", "de": "Bei 0 blockieren"}, + "billingAddress": {"en": "Billing Address", "de": "Rechnungsadresse"}, + "notifyEmails": {"en": "Notification Emails", "de": "Benachrichtigungs-Emails"}, + "notifyOnWarning": {"en": "Notify on Warning", "de": "Bei Warnung benachrichtigen"}, + }, +) + + +class UsageStatistics(BaseModel): + """Aggregated usage statistics for quick retrieval.""" + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) + accountId: str = Field(..., description="Foreign key to BillingAccount") + periodType: PeriodTypeEnum = Field(..., description="Period type") + periodStart: date = Field(..., description="Period start date") + + # Aggregated values + totalCostCHF: float = Field(default=0.0, description="Total cost in CHF") + transactionCount: int = Field(default=0, description="Number of transactions") + + # Breakdown by provider + costByProvider: Dict[str, float] = Field( + default_factory=dict, + description="Cost breakdown by provider (e.g., {'anthropic': 12.50, 'openai': 8.30})" + ) + + # Breakdown by feature + costByFeature: Dict[str, float] = Field( + default_factory=dict, + description="Cost breakdown by feature (e.g., {'chatplayground': 15.00, 'automation': 5.80})" + ) + + +registerModelLabels( + "UsageStatistics", + {"en": "Usage Statistics", "de": "Nutzungsstatistik"}, + { + "id": {"en": "ID", "de": "ID"}, + "accountId": {"en": "Account ID", "de": "Konto-ID"}, + "periodType": {"en": "Period Type", "de": "Periodentyp"}, + "periodStart": {"en": "Period Start", "de": "Periodenbeginn"}, + "totalCostCHF": {"en": "Total Cost (CHF)", "de": "Gesamtkosten (CHF)"}, + "transactionCount": {"en": "Transaction Count", "de": "Anzahl Transaktionen"}, + "costByProvider": {"en": "Cost by Provider", "de": "Kosten nach Anbieter"}, + "costByFeature": {"en": "Cost by Feature", "de": "Kosten nach Feature"}, + }, +) + + +# ============================================================================ +# Response Models for API +# ============================================================================ + +class BillingBalanceResponse(BaseModel): + """Response model for balance endpoint.""" + mandateId: str + mandateName: str + billingModel: BillingModelEnum + balance: float + currency: str = "CHF" + warningThreshold: float + isWarning: bool + creditLimit: Optional[float] = None + + +class BillingStatisticsChartData(BaseModel): + """Chart data point for statistics.""" + label: str + totalCost: float + byProvider: Dict[str, float] + + +class BillingStatisticsResponse(BaseModel): + """Response model for statistics endpoint.""" + mandateId: str + period: PeriodTypeEnum + year: int + month: Optional[int] = None + currency: str = "CHF" + data: List[BillingStatisticsChartData] + totals: Dict[str, Any] + + +class BillingCheckResult(BaseModel): + """Result of a billing balance check.""" + allowed: bool + reason: Optional[str] = None + currentBalance: Optional[float] = None + requiredAmount: Optional[float] = None + billingModel: Optional[BillingModelEnum] = None diff --git a/modules/features/automation/datamodelFeatureAutomation.py b/modules/features/automation/datamodelFeatureAutomation.py index 6d1e906f..b6b32c22 100644 --- a/modules/features/automation/datamodelFeatureAutomation.py +++ b/modules/features/automation/datamodelFeatureAutomation.py @@ -25,6 +25,7 @@ class AutomationDefinition(BaseModel): eventId: Optional[str] = Field(None, description="Event ID from event management (None if not registered)", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) status: Optional[str] = Field(None, description="Status: 'active' if event is registered, 'inactive' if not (computed, readonly)", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) executionLogs: List[Dict[str, Any]] = Field(default_factory=list, description="List of execution logs, each containing timestamp, workflowId, status, and messages", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) + allowedProviders: List[str] = Field(default_factory=list, description="List of allowed AICore providers (e.g., 'anthropic', 'openai'). Empty means all RBAC-permitted providers are allowed.", json_schema_extra={"frontend_type": "multiselect", "frontend_readonly": False, "frontend_required": False}) registerModelLabels( @@ -42,6 +43,7 @@ registerModelLabels( "eventId": {"en": "Event ID", "ge": "Event-ID", "fr": "ID de l'événement"}, "status": {"en": "Status", "ge": "Status", "fr": "Statut"}, "executionLogs": {"en": "Execution Logs", "ge": "Ausführungsprotokolle", "fr": "Journaux d'exécution"}, + "allowedProviders": {"en": "Allowed Providers", "ge": "Erlaubte Provider", "fr": "Fournisseurs autorisés"}, }, ) diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py index 0b630f85..d6f0f063 100644 --- a/modules/interfaces/interfaceBootstrap.py +++ b/modules/interfaces/interfaceBootstrap.py @@ -76,6 +76,10 @@ def initBootstrap(db: DatabaseConnector) -> None: # Initialize feature instances for root mandate if mandateId: initRootMandateFeatures(db, mandateId) + + # Initialize billing settings for root mandate + if mandateId: + initRootMandateBilling(mandateId) def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str] = None) -> None: @@ -1192,6 +1196,115 @@ def _createResourceContextRules(db: DatabaseConnector) -> None: db.recordCreate(AccessRule, rule) logger.info(f"Created {len(resourceRules)} RESOURCE context rules") + + # Create AICore provider RBAC rules + _createAicoreProviderRules(db) + + +def _createAicoreProviderRules(db: DatabaseConnector) -> None: + """ + Create RBAC rules for AICore providers (resource.aicore.{provider}). + All roles get access to all providers by default. + + NOTE: Provider list is dynamically discovered from AICore model registry. + + Args: + db: Database connector instance + """ + try: + from modules.aicore.aicoreModelRegistry import modelRegistry + + # Discover available connectors dynamically + connectors = modelRegistry.discoverConnectors() + providers = [c.getConnectorType() for c in connectors] + + if not providers: + logger.warning("No AICore providers discovered, skipping provider RBAC rules") + return + + logger.info(f"Creating RBAC rules for AICore providers: {providers}") + + providerRules = [] + + # All roles get access to all providers (as per requirement) + for roleLabel in ["admin", "user", "viewer"]: + roleId = _getRoleId(db, roleLabel) + if not roleId: + continue + + for provider in providers: + resourceKey = f"resource.aicore.{provider}" + + # Check if rule already exists + existingRules = db.getRecordset( + AccessRule, + recordFilter={ + "roleId": roleId, + "context": AccessRuleContext.RESOURCE.value, + "item": resourceKey + } + ) + + if not existingRules: + providerRules.append(AccessRule( + roleId=roleId, + context=AccessRuleContext.RESOURCE, + item=resourceKey, + view=True, # view=True means "can use" for RESOURCE context + read=None, + create=None, + update=None, + delete=None, + )) + + for rule in providerRules: + db.recordCreate(AccessRule, rule) + + if providerRules: + logger.info(f"Created {len(providerRules)} AICore provider RBAC rules") + else: + logger.debug("All AICore provider RBAC rules already exist") + + except Exception as e: + logger.warning(f"Failed to create AICore provider RBAC rules: {e}") + + +def initRootMandateBilling(mandateId: str) -> None: + """ + Initialize billing settings for root mandate. + Root mandate uses PREPAY_USER model with 10 CHF initial credit per user. + + Args: + mandateId: Root mandate ID + """ + try: + from modules.interfaces.interfaceDbBilling import _getRootInterface + from modules.datamodels.datamodelBilling import BillingSettings, BillingModelEnum + + billingInterface = _getRootInterface() + + # Check if settings already exist + existingSettings = billingInterface.getSettings(mandateId) + if existingSettings: + logger.info("Billing settings for root mandate already exist") + return + + # Create billing settings for root mandate + settings = BillingSettings( + mandateId=mandateId, + billingModel=BillingModelEnum.PREPAY_USER, + defaultUserCredit=10.0, # 10 CHF initial credit per user + warningThresholdPercent=10.0, + blockOnZeroBalance=True, + notifyOnWarning=True + ) + + billingInterface.createSettings(settings) + logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit") + + except Exception as e: + # Don't fail bootstrap if billing init fails + logger.warning(f"Failed to initialize root mandate billing (non-critical): {e}") def assignInitialUserMemberships( diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py new file mode 100644 index 00000000..ebbe3ae5 --- /dev/null +++ b/modules/interfaces/interfaceDbBilling.py @@ -0,0 +1,734 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Interface for Billing operations. +Manages billing accounts, transactions, and usage statistics. + +All billing data is stored in the poweron_billing database. +""" + +import logging +from typing import Dict, Any, List, Optional +from datetime import date, datetime, timedelta +import uuid + +from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.shared.configuration import APP_CONFIG +from modules.shared.timeUtils import getUtcTimestamp +from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelBilling import ( + BillingAccount, + BillingTransaction, + BillingSettings, + UsageStatistics, + BillingAddress, + BillingModelEnum, + AccountTypeEnum, + TransactionTypeEnum, + ReferenceTypeEnum, + PeriodTypeEnum, + BillingBalanceResponse, + BillingCheckResult, +) + +logger = logging.getLogger(__name__) + +# Singleton factory for BillingObjects instances +_billingInterfaces: Dict[str, "BillingObjects"] = {} + +# Database name for billing +BILLING_DATABASE = "poweron_billing" + + +def getInterface(currentUser: User, mandateId: str = None) -> "BillingObjects": + """ + Factory function to get or create a BillingObjects instance. + + Args: + currentUser: Current user object + mandateId: Mandate ID for context + + Returns: + BillingObjects instance + """ + cacheKey = f"{currentUser.id}_{mandateId}" + + if cacheKey not in _billingInterfaces: + _billingInterfaces[cacheKey] = BillingObjects(currentUser, mandateId) + else: + _billingInterfaces[cacheKey].setUserContext(currentUser, mandateId) + + return _billingInterfaces[cacheKey] + + +def _getRootInterface() -> "BillingObjects": + """Get interface with system access for bootstrap operations.""" + from modules.security.rootAccess import getRootUser + rootUser = getRootUser() + return BillingObjects(rootUser, mandateId=None) + + +class BillingObjects: + """ + Interface for billing operations. + Manages accounts, transactions, settings, and statistics. + """ + + def __init__(self, currentUser: Optional[User] = None, mandateId: str = None): + """ + Initialize the billing interface. + + Args: + currentUser: Current user object + mandateId: Mandate ID for context + """ + self.currentUser = currentUser + self.userId = currentUser.id if currentUser else None + self.mandateId = mandateId + + # Initialize database connection + self._initializeDatabase() + + def _initializeDatabase(self): + """Initialize database connection.""" + self.db = DatabaseConnector( + databaseName=BILLING_DATABASE, + host=APP_CONFIG.get('Database_Host', 'localhost'), + port=int(APP_CONFIG.get('Database_Port', '5432')), + user=APP_CONFIG.get('Database_User', 'admin'), + password=APP_CONFIG.get('Database_Password', 'admin') + ) + + def setUserContext(self, currentUser: User, mandateId: str = None): + """ + Update user context. + + Args: + currentUser: Current user object + mandateId: Mandate ID for context + """ + self.currentUser = currentUser + self.userId = currentUser.id if currentUser else None + self.mandateId = mandateId + + # ========================================================================= + # BillingSettings Operations + # ========================================================================= + + def getSettings(self, mandateId: str) -> Optional[Dict[str, Any]]: + """ + Get billing settings for a mandate. + + Args: + mandateId: Mandate ID + + Returns: + BillingSettings dict or None if not found + """ + try: + results = self.db.getRecordset( + BillingSettings, + filterDict={"mandateId": mandateId} + ) + return results[0] if results else None + except Exception as e: + logger.error(f"Error getting billing settings: {e}") + return None + + def createSettings(self, settings: BillingSettings) -> Dict[str, Any]: + """ + Create billing settings for a mandate. + + Args: + settings: BillingSettings object + + Returns: + Created settings dict + """ + settingsDict = settings.model_dump(exclude_none=True) + + # Handle nested BillingAddress + if settings.billingAddress: + settingsDict["billingAddress"] = settings.billingAddress.model_dump() + + return self.db.recordCreate(BillingSettings, settingsDict) + + def updateSettings(self, settingsId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Update billing settings. + + Args: + settingsId: Settings ID + updates: Fields to update + + Returns: + Updated settings dict or None + """ + return self.db.recordModify(BillingSettings, settingsId, updates) + + def getOrCreateSettings(self, mandateId: str, defaultModel: BillingModelEnum = BillingModelEnum.UNLIMITED) -> Dict[str, Any]: + """ + Get or create billing settings for a mandate. + + Args: + mandateId: Mandate ID + defaultModel: Default billing model if creating + + Returns: + BillingSettings dict + """ + existing = self.getSettings(mandateId) + if existing: + return existing + + settings = BillingSettings( + mandateId=mandateId, + billingModel=defaultModel, + defaultUserCredit=10.0, + warningThresholdPercent=10.0, + blockOnZeroBalance=True, + notifyOnWarning=True + ) + return self.createSettings(settings) + + # ========================================================================= + # BillingAccount Operations + # ========================================================================= + + def getAccount(self, accountId: str) -> Optional[Dict[str, Any]]: + """Get a billing account by ID.""" + try: + results = self.db.getRecordset( + BillingAccount, + filterDict={"id": accountId} + ) + return results[0] if results else None + except Exception as e: + logger.error(f"Error getting billing account: {e}") + return None + + def getMandateAccount(self, mandateId: str) -> Optional[Dict[str, Any]]: + """ + Get the mandate-level billing account. + + Args: + mandateId: Mandate ID + + Returns: + BillingAccount dict or None + """ + try: + results = self.db.getRecordset( + BillingAccount, + filterDict={ + "mandateId": mandateId, + "accountType": AccountTypeEnum.MANDATE.value + } + ) + return results[0] if results else None + except Exception as e: + logger.error(f"Error getting mandate account: {e}") + return None + + def getUserAccount(self, mandateId: str, userId: str) -> Optional[Dict[str, Any]]: + """ + Get a user-level billing account within a mandate. + + Args: + mandateId: Mandate ID + userId: User ID + + Returns: + BillingAccount dict or None + """ + try: + results = self.db.getRecordset( + BillingAccount, + filterDict={ + "mandateId": mandateId, + "userId": userId, + "accountType": AccountTypeEnum.USER.value + } + ) + return results[0] if results else None + except Exception as e: + logger.error(f"Error getting user account: {e}") + return None + + def createAccount(self, account: BillingAccount) -> Dict[str, Any]: + """ + Create a new billing account. + + Args: + account: BillingAccount object + + Returns: + Created account dict + """ + accountDict = account.model_dump(exclude_none=True) + return self.db.recordCreate(BillingAccount, accountDict) + + def updateAccountBalance(self, accountId: str, newBalance: float) -> Optional[Dict[str, Any]]: + """ + Update account balance atomically. + + Args: + accountId: Account ID + newBalance: New balance value + + Returns: + Updated account dict or None + """ + return self.db.recordModify(BillingAccount, accountId, {"balance": newBalance}) + + def getOrCreateMandateAccount(self, mandateId: str, initialBalance: float = 0.0) -> Dict[str, Any]: + """ + Get or create a mandate-level billing account. + + Args: + mandateId: Mandate ID + initialBalance: Initial balance if creating + + Returns: + BillingAccount dict + """ + existing = self.getMandateAccount(mandateId) + if existing: + return existing + + account = BillingAccount( + mandateId=mandateId, + accountType=AccountTypeEnum.MANDATE, + balance=initialBalance, + enabled=True + ) + return self.createAccount(account) + + def getOrCreateUserAccount(self, mandateId: str, userId: str, initialBalance: float = 0.0) -> Dict[str, Any]: + """ + Get or create a user-level billing account. + + Args: + mandateId: Mandate ID + userId: User ID + initialBalance: Initial balance if creating + + Returns: + BillingAccount dict + """ + existing = self.getUserAccount(mandateId, userId) + if existing: + return existing + + account = BillingAccount( + mandateId=mandateId, + userId=userId, + accountType=AccountTypeEnum.USER, + balance=initialBalance, + enabled=True + ) + created = self.createAccount(account) + + # If initial balance > 0, create a SYSTEM credit transaction + if initialBalance > 0: + self.createTransaction(BillingTransaction( + accountId=created["id"], + transactionType=TransactionTypeEnum.CREDIT, + amount=initialBalance, + description="Initial credit for new user", + referenceType=ReferenceTypeEnum.SYSTEM + )) + + return created + + # ========================================================================= + # BillingTransaction Operations + # ========================================================================= + + def createTransaction(self, transaction: BillingTransaction) -> Dict[str, Any]: + """ + Create a new billing transaction and update account balance. + + Args: + transaction: BillingTransaction object + + Returns: + Created transaction dict + """ + # Get current account + account = self.getAccount(transaction.accountId) + if not account: + raise ValueError(f"Account {transaction.accountId} not found") + + currentBalance = account.get("balance", 0.0) + + # Calculate new balance + if transaction.transactionType == TransactionTypeEnum.CREDIT: + newBalance = currentBalance + transaction.amount + elif transaction.transactionType == TransactionTypeEnum.DEBIT: + newBalance = currentBalance - transaction.amount + else: # ADJUSTMENT + newBalance = currentBalance + transaction.amount # Can be positive or negative + + # Create transaction + transactionDict = transaction.model_dump(exclude_none=True) + created = self.db.recordCreate(BillingTransaction, transactionDict) + + # Update account balance + self.updateAccountBalance(transaction.accountId, newBalance) + + logger.info(f"Billing transaction created: {transaction.transactionType.value} {transaction.amount} CHF, " + f"balance: {currentBalance} -> {newBalance}") + + return created + + def getTransactions( + self, + accountId: str, + limit: int = 100, + offset: int = 0, + startDate: date = None, + endDate: date = None + ) -> List[Dict[str, Any]]: + """ + Get transactions for an account. + + Args: + accountId: Account ID + limit: Maximum number of results + offset: Offset for pagination + startDate: Filter by start date + endDate: Filter by end date + + Returns: + List of transaction dicts + """ + try: + filterDict = {"accountId": accountId} + results = self.db.getRecordset(BillingTransaction, filterDict=filterDict) + + # Apply date filters if provided + if startDate or endDate: + filtered = [] + for t in results: + createdAt = t.get("_createdAt") + if createdAt: + tDate = createdAt.date() if isinstance(createdAt, datetime) else createdAt + if startDate and tDate < startDate: + continue + if endDate and tDate > endDate: + continue + filtered.append(t) + results = filtered + + # Sort by creation date descending + results.sort(key=lambda x: x.get("_createdAt", ""), reverse=True) + + # Apply pagination + return results[offset:offset + limit] + except Exception as e: + logger.error(f"Error getting transactions: {e}") + return [] + + def getTransactionsByMandate(self, mandateId: str, limit: int = 100) -> List[Dict[str, Any]]: + """ + Get all transactions for a mandate (across all accounts). + + Args: + mandateId: Mandate ID + limit: Maximum number of results + + Returns: + List of transaction dicts + """ + # Get all accounts for mandate + accounts = self.db.getRecordset(BillingAccount, filterDict={"mandateId": mandateId}) + + allTransactions = [] + for account in accounts: + transactions = self.getTransactions(account["id"], limit=limit) + allTransactions.extend(transactions) + + # Sort by creation date descending and limit + allTransactions.sort(key=lambda x: x.get("_createdAt", ""), reverse=True) + return allTransactions[:limit] + + # ========================================================================= + # Balance Check Operations + # ========================================================================= + + def checkBalance(self, mandateId: str, userId: str, estimatedCost: float) -> BillingCheckResult: + """ + Check if there's sufficient balance for an operation. + + Args: + mandateId: Mandate ID + userId: User ID + estimatedCost: Estimated cost of the operation + + Returns: + BillingCheckResult + """ + settings = self.getSettings(mandateId) + if not settings: + # No settings = no billing = allowed + return BillingCheckResult(allowed=True, billingModel=BillingModelEnum.UNLIMITED) + + billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) + + # UNLIMITED = always allowed + if billingModel == BillingModelEnum.UNLIMITED: + return BillingCheckResult(allowed=True, billingModel=billingModel) + + # Get the relevant account + if billingModel == BillingModelEnum.PREPAY_USER: + account = self.getUserAccount(mandateId, userId) + else: + account = self.getMandateAccount(mandateId) + + if not account: + # No account = no balance = potentially blocked + if settings.get("blockOnZeroBalance", True): + return BillingCheckResult( + allowed=False, + reason="NO_ACCOUNT", + currentBalance=0.0, + requiredAmount=estimatedCost, + billingModel=billingModel + ) + return BillingCheckResult(allowed=True, currentBalance=0.0, billingModel=billingModel) + + currentBalance = account.get("balance", 0.0) + + # CREDIT_POSTPAY with credit limit check + if billingModel == BillingModelEnum.CREDIT_POSTPAY: + creditLimit = account.get("creditLimit") + if creditLimit and abs(currentBalance) + estimatedCost > creditLimit: + return BillingCheckResult( + allowed=False, + reason="CREDIT_LIMIT_EXCEEDED", + currentBalance=currentBalance, + requiredAmount=estimatedCost, + billingModel=billingModel + ) + return BillingCheckResult(allowed=True, currentBalance=currentBalance, billingModel=billingModel) + + # PREPAY models - check balance + if currentBalance < estimatedCost: + if settings.get("blockOnZeroBalance", True): + return BillingCheckResult( + allowed=False, + reason="INSUFFICIENT_BALANCE", + currentBalance=currentBalance, + requiredAmount=estimatedCost, + billingModel=billingModel + ) + + return BillingCheckResult(allowed=True, currentBalance=currentBalance, billingModel=billingModel) + + def recordUsage( + self, + mandateId: str, + userId: str, + priceCHF: float, + workflowId: str = None, + featureInstanceId: str = None, + featureCode: str = None, + aicoreProvider: str = None, + description: str = "AI Usage" + ) -> Optional[Dict[str, Any]]: + """ + Record usage cost as a billing transaction. + + Args: + mandateId: Mandate ID + userId: User ID + priceCHF: Cost in CHF + workflowId: Optional workflow ID + featureInstanceId: Optional feature instance ID + featureCode: Optional feature code + aicoreProvider: Optional AICore provider name + description: Transaction description + + Returns: + Created transaction dict or None + """ + if priceCHF <= 0: + return None + + settings = self.getSettings(mandateId) + if not settings: + logger.debug(f"No billing settings for mandate {mandateId}, skipping usage recording") + return None + + billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) + + # UNLIMITED = no transaction recording + if billingModel == BillingModelEnum.UNLIMITED: + return None + + # Get or create the relevant account + if billingModel == BillingModelEnum.PREPAY_USER: + account = self.getOrCreateUserAccount(mandateId, userId) + else: + account = self.getOrCreateMandateAccount(mandateId) + + # Create debit transaction + transaction = BillingTransaction( + accountId=account["id"], + transactionType=TransactionTypeEnum.DEBIT, + amount=priceCHF, + description=description, + referenceType=ReferenceTypeEnum.WORKFLOW, + workflowId=workflowId, + featureInstanceId=featureInstanceId, + featureCode=featureCode, + aicoreProvider=aicoreProvider + ) + + return self.createTransaction(transaction) + + # ========================================================================= + # Statistics Operations + # ========================================================================= + + def getUsageStatistics( + self, + accountId: str, + periodType: PeriodTypeEnum, + year: int, + month: int = None + ) -> List[Dict[str, Any]]: + """ + Get usage statistics for an account. + + Args: + accountId: Account ID + periodType: Period type (DAY, MONTH, YEAR) + year: Year + month: Month (for DAY period type) + + Returns: + List of statistics dicts + """ + filterDict = { + "accountId": accountId, + "periodType": periodType.value + } + + results = self.db.getRecordset(UsageStatistics, filterDict=filterDict) + + # Filter by year + filtered = [s for s in results if s.get("periodStart") and s["periodStart"].year == year] + + # Filter by month if specified + if month and periodType == PeriodTypeEnum.DAY: + filtered = [s for s in filtered if s["periodStart"].month == month] + + return sorted(filtered, key=lambda x: x.get("periodStart", date.min)) + + def calculateStatisticsFromTransactions( + self, + accountId: str, + startDate: date, + endDate: date + ) -> Dict[str, Any]: + """ + Calculate statistics from transactions for a period. + + Args: + accountId: Account ID + startDate: Start date + endDate: End date + + Returns: + Statistics dict + """ + transactions = self.getTransactions(accountId, limit=10000, startDate=startDate, endDate=endDate) + + # Filter only DEBIT transactions (usage) + debits = [t for t in transactions if t.get("transactionType") == TransactionTypeEnum.DEBIT.value] + + totalCost = sum(t.get("amount", 0) for t in debits) + + # Calculate by provider + costByProvider = {} + for t in debits: + provider = t.get("aicoreProvider", "unknown") + costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0) + + # Calculate by feature + costByFeature = {} + for t in debits: + feature = t.get("featureCode", "unknown") + costByFeature[feature] = costByFeature.get(feature, 0) + t.get("amount", 0) + + return { + "totalCostCHF": totalCost, + "transactionCount": len(debits), + "costByProvider": costByProvider, + "costByFeature": costByFeature + } + + # ========================================================================= + # Utility Methods + # ========================================================================= + + def getBalancesForUser(self, userId: str) -> List[BillingBalanceResponse]: + """ + Get all billing balances for a user across mandates. + + Args: + userId: User ID + + Returns: + List of BillingBalanceResponse + """ + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + + balances = [] + + # Get all mandates the user belongs to + try: + appInterface = getAppInterface(self.currentUser) + userMandates = appInterface.getUserMandates(userId) + + for um in userMandates: + mandateId = um.get("mandateId") + mandate = appInterface.getMandate(mandateId) + if not mandate: + continue + + settings = self.getSettings(mandateId) + if not settings: + continue + + billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) + + # Get the relevant account + if billingModel == BillingModelEnum.PREPAY_USER: + account = self.getUserAccount(mandateId, userId) + elif billingModel in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]: + account = self.getMandateAccount(mandateId) + else: + continue + + if not account: + continue + + balance = account.get("balance", 0.0) + warningThreshold = account.get("warningThreshold", 0.0) + + balances.append(BillingBalanceResponse( + mandateId=mandateId, + mandateName=mandate.get("name", ""), + billingModel=billingModel, + balance=balance, + warningThreshold=warningThreshold, + isWarning=balance <= warningThreshold, + creditLimit=account.get("creditLimit") + )) + except Exception as e: + logger.error(f"Error getting balances for user: {e}") + + return balances diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py new file mode 100644 index 00000000..c191e793 --- /dev/null +++ b/modules/routes/routeBilling.py @@ -0,0 +1,557 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Billing routes for the backend API. +Implements the endpoints for billing management and usage tracking. + +Features: +- User endpoints: View balance, transactions, statistics +- Admin endpoints: Manage settings, add credits, view all accounts +""" + +from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query +from typing import List, Dict, Any, Optional +from fastapi import status +import logging +from datetime import date, datetime +from pydantic import BaseModel, Field + +# Import auth module +from modules.auth import limiter, requireSysAdmin, getRequestContext, RequestContext + +# Import billing components +from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface +from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService +from modules.datamodels.datamodelBilling import ( + BillingAccount, + BillingTransaction, + BillingSettings, + BillingAddress, + BillingModelEnum, + TransactionTypeEnum, + ReferenceTypeEnum, + PeriodTypeEnum, + BillingBalanceResponse, + BillingStatisticsResponse, + BillingStatisticsChartData, + BillingCheckResult, +) + +# Configure logger +logger = logging.getLogger(__name__) + +# ============================================================================= +# Request/Response Models +# ============================================================================= + +class CreditAddRequest(BaseModel): + """Request model for adding credit to an account.""" + userId: Optional[str] = Field(None, description="Target user ID (for PREPAY_USER model)") + amount: float = Field(..., gt=0, description="Amount to credit in CHF") + description: str = Field(default="Manual credit", description="Transaction description") + + +class BillingSettingsUpdate(BaseModel): + """Request model for updating billing settings.""" + billingModel: Optional[BillingModelEnum] = None + defaultUserCredit: Optional[float] = Field(None, ge=0) + warningThresholdPercent: Optional[float] = Field(None, ge=0, le=100) + blockOnZeroBalance: Optional[bool] = None + notifyOnWarning: Optional[bool] = None + notifyEmails: Optional[List[str]] = None + billingAddress: Optional[BillingAddress] = None + + +class TransactionResponse(BaseModel): + """Response model for a billing transaction.""" + id: str + accountId: str + transactionType: TransactionTypeEnum + amount: float + description: str + referenceType: Optional[ReferenceTypeEnum] + workflowId: Optional[str] + featureCode: Optional[str] + aicoreProvider: Optional[str] + createdAt: Optional[datetime] + + +class AccountSummary(BaseModel): + """Summary of a billing account.""" + id: str + mandateId: str + userId: Optional[str] + accountType: str + balance: float + creditLimit: Optional[float] + warningThreshold: float + enabled: bool + + +class UsageReportResponse(BaseModel): + """Usage report for a period.""" + period: str + totalCost: float + transactionCount: int + costByProvider: Dict[str, float] + costByFeature: Dict[str, float] + + +# ============================================================================= +# Router Setup +# ============================================================================= + +router = APIRouter( + prefix="/api/billing", + tags=["Billing"], + responses={404: {"description": "Not found"}} +) + +# ============================================================================= +# User Endpoints +# ============================================================================= + +@router.get("/balance", response_model=List[BillingBalanceResponse]) +@limiter.limit("60/minute") +async def getBalance( + request: Request, + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get billing balances for all mandates the current user belongs to. + Returns balance information for each mandate. + """ + try: + billingService = getBillingService( + ctx.currentUser, + ctx.mandateId, + featureCode="billing" + ) + + balances = billingService.getBalancesForUser() + return balances + + except Exception as e: + logger.error(f"Error getting billing balance: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/balance/{mandateId}", response_model=BillingBalanceResponse) +@limiter.limit("60/minute") +async def getBalanceForMandate( + request: Request, + mandateId: str = Path(..., description="Mandate ID"), + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get billing balance for a specific mandate. + """ + try: + billingService = getBillingService( + ctx.currentUser, + mandateId, + featureCode="billing" + ) + + # Check balance + checkResult = billingService.checkBalance(0.0) + + # Get mandate name from app interface + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + appInterface = getAppInterface(ctx.currentUser, mandateId=mandateId) + mandate = appInterface.getMandate(mandateId) + mandateName = mandate.get("name", "") if mandate else "" + + return BillingBalanceResponse( + mandateId=mandateId, + mandateName=mandateName, + billingModel=checkResult.billingModel or BillingModelEnum.UNLIMITED, + balance=checkResult.currentBalance or 0.0, + warningThreshold=0.0, # TODO: Get from account + isWarning=False, + creditLimit=None + ) + + except Exception as e: + logger.error(f"Error getting billing balance for mandate {mandateId}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/transactions", response_model=List[TransactionResponse]) +@limiter.limit("30/minute") +async def getTransactions( + request: Request, + limit: int = Query(default=50, ge=1, le=500), + offset: int = Query(default=0, ge=0), + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get transaction history for the current mandate. + """ + try: + billingService = getBillingService( + ctx.currentUser, + ctx.mandateId, + featureCode="billing" + ) + + transactions = billingService.getTransactionHistory(limit=limit) + + # Convert to response model + result = [] + for t in transactions[offset:offset + limit]: + result.append(TransactionResponse( + id=t.get("id"), + accountId=t.get("accountId"), + transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")), + amount=t.get("amount", 0.0), + description=t.get("description", ""), + referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, + workflowId=t.get("workflowId"), + featureCode=t.get("featureCode"), + aicoreProvider=t.get("aicoreProvider"), + createdAt=t.get("_createdAt") + )) + + return result + + except Exception as e: + logger.error(f"Error getting billing transactions: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/statistics/{period}", response_model=UsageReportResponse) +@limiter.limit("30/minute") +async def getStatistics( + request: Request, + period: str = Path(..., description="Period: 'day', 'month', or 'year'"), + year: int = Query(..., description="Year"), + month: Optional[int] = Query(None, description="Month (1-12, required for 'day' period)"), + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get usage statistics for a period. + """ + try: + # Validate period + if period not in ["day", "month", "year"]: + raise HTTPException(status_code=400, detail="Invalid period. Use 'day', 'month', or 'year'") + + if period == "day" and not month: + raise HTTPException(status_code=400, detail="Month is required for 'day' period") + + billingInterface = getBillingInterface(ctx.currentUser, ctx.mandateId) + settings = billingInterface.getSettings(ctx.mandateId) + + if not settings: + return UsageReportResponse( + period=period, + totalCost=0.0, + transactionCount=0, + costByProvider={}, + costByFeature={} + ) + + billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) + + # Get the relevant account + if billingModel == BillingModelEnum.PREPAY_USER: + account = billingInterface.getUserAccount(ctx.mandateId, ctx.currentUser.id) + else: + account = billingInterface.getMandateAccount(ctx.mandateId) + + if not account: + return UsageReportResponse( + period=period, + totalCost=0.0, + transactionCount=0, + costByProvider={}, + costByFeature={} + ) + + # Calculate date range + if period == "day": + startDate = date(year, month, 1) + if month == 12: + endDate = date(year + 1, 1, 1) + else: + endDate = date(year, month + 1, 1) + elif period == "month": + startDate = date(year, 1, 1) + endDate = date(year + 1, 1, 1) + else: # year + startDate = date(year, 1, 1) + endDate = date(year + 1, 1, 1) + + # Get statistics from transactions + stats = billingInterface.calculateStatisticsFromTransactions( + account["id"], + startDate, + endDate + ) + + return UsageReportResponse( + period=period, + totalCost=stats.get("totalCostCHF", 0.0), + transactionCount=stats.get("transactionCount", 0), + costByProvider=stats.get("costByProvider", {}), + costByFeature=stats.get("costByFeature", {}) + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting billing statistics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/providers", response_model=List[str]) +@limiter.limit("60/minute") +async def getAllowedProviders( + request: Request, + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get list of AICore providers the current user is allowed to use. + """ + try: + billingService = getBillingService( + ctx.currentUser, + ctx.mandateId, + featureCode="billing" + ) + + return billingService.getallowedProviders() + + except Exception as e: + logger.error(f"Error getting allowed providers: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# Admin Endpoints +# ============================================================================= + +@router.get("/admin/settings/{mandateId}", response_model=Dict[str, Any]) +@limiter.limit("30/minute") +@requireSysAdmin +async def getSettingsAdmin( + request: Request, + mandateId: str = Path(..., description="Mandate ID"), + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get billing settings for a mandate (SysAdmin only). + """ + try: + billingInterface = getBillingInterface(ctx.currentUser, mandateId) + settings = billingInterface.getSettings(mandateId) + + if not settings: + raise HTTPException(status_code=404, detail="Billing settings not found") + + return settings + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting billing settings: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/admin/settings/{mandateId}", response_model=Dict[str, Any]) +@limiter.limit("10/minute") +@requireSysAdmin +async def createOrUpdateSettings( + request: Request, + mandateId: str = Path(..., description="Mandate ID"), + settingsUpdate: BillingSettingsUpdate = Body(...), + ctx: RequestContext = Depends(getRequestContext) +): + """ + Create or update billing settings for a mandate (SysAdmin only). + """ + try: + billingInterface = getBillingInterface(ctx.currentUser, mandateId) + existingSettings = billingInterface.getSettings(mandateId) + + if existingSettings: + # Update existing settings + updates = settingsUpdate.model_dump(exclude_none=True) + if updates: + result = billingInterface.updateSettings(existingSettings["id"], updates) + return result or existingSettings + return existingSettings + else: + # Create new settings + from modules.datamodels.datamodelBilling import BillingSettings + + newSettings = BillingSettings( + mandateId=mandateId, + billingModel=settingsUpdate.billingModel or BillingModelEnum.UNLIMITED, + defaultUserCredit=settingsUpdate.defaultUserCredit or 10.0, + warningThresholdPercent=settingsUpdate.warningThresholdPercent or 10.0, + blockOnZeroBalance=settingsUpdate.blockOnZeroBalance if settingsUpdate.blockOnZeroBalance is not None else True, + notifyOnWarning=settingsUpdate.notifyOnWarning if settingsUpdate.notifyOnWarning is not None else True, + notifyEmails=settingsUpdate.notifyEmails or [], + billingAddress=settingsUpdate.billingAddress + ) + + return billingInterface.createSettings(newSettings) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating billing settings: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/admin/credit/{mandateId}", response_model=Dict[str, Any]) +@limiter.limit("10/minute") +@requireSysAdmin +async def addCredit( + request: Request, + mandateId: str = Path(..., description="Mandate ID"), + creditRequest: CreditAddRequest = Body(...), + ctx: RequestContext = Depends(getRequestContext) +): + """ + Add credit to a billing account (SysAdmin only). + For PREPAY_USER model, specify userId. For PREPAY_MANDATE, leave userId empty. + """ + try: + # Get settings to determine billing model + billingInterface = getBillingInterface(ctx.currentUser, mandateId) + settings = billingInterface.getSettings(mandateId) + + if not settings: + raise HTTPException(status_code=404, detail="Billing settings not found for this mandate") + + billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) + + # Validate request based on billing model + if billingModel == BillingModelEnum.PREPAY_USER: + if not creditRequest.userId: + raise HTTPException(status_code=400, detail="userId is required for PREPAY_USER model") + + # Create user-level account if needed and add credit + account = billingInterface.getOrCreateUserAccount( + mandateId, + creditRequest.userId, + initialBalance=0.0 + ) + elif billingModel in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]: + # Create mandate-level account if needed and add credit + account = billingInterface.getOrCreateMandateAccount(mandateId, initialBalance=0.0) + else: + raise HTTPException(status_code=400, detail=f"Cannot add credit to {billingModel.value} billing model") + + # Create credit transaction + from modules.datamodels.datamodelBilling import BillingTransaction + + transaction = BillingTransaction( + accountId=account["id"], + transactionType=TransactionTypeEnum.CREDIT, + amount=creditRequest.amount, + description=creditRequest.description, + referenceType=ReferenceTypeEnum.ADMIN + ) + + result = billingInterface.createTransaction(transaction) + + logger.info(f"Added {creditRequest.amount} CHF credit to account {account['id']} in mandate {mandateId}") + + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error adding credit: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/admin/accounts/{mandateId}", response_model=List[AccountSummary]) +@limiter.limit("30/minute") +@requireSysAdmin +async def getAccounts( + request: Request, + mandateId: str = Path(..., description="Mandate ID"), + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get all billing accounts for a mandate (SysAdmin only). + """ + try: + billingInterface = getBillingInterface(ctx.currentUser, mandateId) + + # Get all accounts for this mandate + from modules.connectors.connectorDbPostgre import DatabaseConnector + from modules.shared.configuration import APP_CONFIG + from modules.datamodels.datamodelBilling import BillingAccount + + db = DatabaseConnector( + databaseName="poweron_billing", + host=APP_CONFIG.get('Database_Host', 'localhost'), + port=int(APP_CONFIG.get('Database_Port', '5432')), + user=APP_CONFIG.get('Database_User', 'admin'), + password=APP_CONFIG.get('Database_Password', 'admin') + ) + + accounts = db.getRecordset(BillingAccount, filterDict={"mandateId": mandateId}) + + result = [] + for acc in accounts: + result.append(AccountSummary( + id=acc.get("id"), + mandateId=acc.get("mandateId"), + userId=acc.get("userId"), + accountType=acc.get("accountType"), + balance=acc.get("balance", 0.0), + creditLimit=acc.get("creditLimit"), + warningThreshold=acc.get("warningThreshold", 0.0), + enabled=acc.get("enabled", True) + )) + + return result + + except Exception as e: + logger.error(f"Error getting billing accounts: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/admin/transactions/{mandateId}", response_model=List[TransactionResponse]) +@limiter.limit("30/minute") +@requireSysAdmin +async def getTransactionsAdmin( + request: Request, + mandateId: str = Path(..., description="Mandate ID"), + limit: int = Query(default=100, ge=1, le=1000), + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get all transactions for a mandate (SysAdmin only). + """ + try: + billingInterface = getBillingInterface(ctx.currentUser, mandateId) + transactions = billingInterface.getTransactionsByMandate(mandateId, limit=limit) + + result = [] + for t in transactions: + result.append(TransactionResponse( + id=t.get("id"), + accountId=t.get("accountId"), + transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")), + amount=t.get("amount", 0.0), + description=t.get("description", ""), + referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, + workflowId=t.get("workflowId"), + featureCode=t.get("featureCode"), + aicoreProvider=t.get("aicoreProvider"), + createdAt=t.get("_createdAt") + )) + + return result + + except Exception as e: + logger.error(f"Error getting billing transactions for mandate {mandateId}: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/modules/services/serviceBilling/__init__.py b/modules/services/serviceBilling/__init__.py new file mode 100644 index 00000000..ab0805d5 --- /dev/null +++ b/modules/services/serviceBilling/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Billing service module.""" + +from .mainServiceBilling import BillingService, getService + +__all__ = ["BillingService", "getService"] diff --git a/modules/services/serviceBilling/mainServiceBilling.py b/modules/services/serviceBilling/mainServiceBilling.py new file mode 100644 index 00000000..709ab61a --- /dev/null +++ b/modules/services/serviceBilling/mainServiceBilling.py @@ -0,0 +1,408 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Billing Service - Central service for billing operations. + +Handles: +- Balance checks before AI operations +- Cost recording after AI operations +- Provider permission checks via RBAC +- Price calculation with markup +""" + +import logging +from typing import Dict, Any, List, Optional +from datetime import datetime + +from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelBilling import ( + BillingModelEnum, + BillingCheckResult, + TransactionTypeEnum, + ReferenceTypeEnum, + BillingTransaction, + BillingBalanceResponse, +) +from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface + +logger = logging.getLogger(__name__) + +# Markup percentage for internal pricing (50% = 1.5x) +BILLING_MARKUP_PERCENT = 50 + +# Singleton cache +_billingServices: Dict[str, "BillingService"] = {} + + +def getService(currentUser: User, mandateId: str, featureInstanceId: str = None, featureCode: str = None) -> "BillingService": + """ + Factory function to get or create a BillingService instance. + + Args: + currentUser: Current user object + mandateId: Mandate ID for context + featureInstanceId: Optional feature instance ID + featureCode: Optional feature code (e.g., 'chatplayground', 'automation') + + Returns: + BillingService instance + """ + cacheKey = f"{currentUser.id}_{mandateId}_{featureInstanceId}" + + if cacheKey not in _billingServices: + _billingServices[cacheKey] = BillingService(currentUser, mandateId, featureInstanceId, featureCode) + else: + _billingServices[cacheKey].setContext(currentUser, mandateId, featureInstanceId, featureCode) + + return _billingServices[cacheKey] + + +class BillingService: + """ + Central billing service for AI operations. + + Responsibilities: + - Check balance before operations + - Record usage costs + - Apply pricing markup + - Check provider permissions via RBAC + """ + + def __init__( + self, + currentUser: User, + mandateId: str, + featureInstanceId: str = None, + featureCode: str = None + ): + """ + Initialize the billing service. + + Args: + currentUser: Current user object + mandateId: Mandate ID + featureInstanceId: Optional feature instance ID + featureCode: Optional feature code + """ + self.currentUser = currentUser + self.mandateId = mandateId + self.featureInstanceId = featureInstanceId + self.featureCode = featureCode + + # Get billing interface + self._billingInterface = getBillingInterface(currentUser, mandateId) + + # Cache settings + self._settingsCache = None + + def setContext( + self, + currentUser: User, + mandateId: str, + featureInstanceId: str = None, + featureCode: str = None + ): + """Update service context.""" + self.currentUser = currentUser + self.mandateId = mandateId + self.featureInstanceId = featureInstanceId + self.featureCode = featureCode + self._billingInterface = getBillingInterface(currentUser, mandateId) + self._settingsCache = None + + def _getSettings(self) -> Optional[Dict[str, Any]]: + """Get billing settings with caching.""" + if self._settingsCache is None: + self._settingsCache = self._billingInterface.getSettings(self.mandateId) + return self._settingsCache + + # ========================================================================= + # Price Calculation + # ========================================================================= + + def calculatePriceWithMarkup(self, basePriceCHF: float) -> float: + """ + Calculate final price with markup. + + The AICore plugins return prices in their original currency (USD). + This method applies the configured markup percentage. + + Args: + basePriceCHF: Base price from AI model (actually USD from provider) + + Returns: + Final price in CHF with markup applied + """ + if basePriceCHF <= 0: + return 0.0 + + # Apply markup (50% = multiply by 1.5) + markup_multiplier = 1 + (BILLING_MARKUP_PERCENT / 100) + return round(basePriceCHF * markup_multiplier, 6) + + # ========================================================================= + # Balance Operations + # ========================================================================= + + def checkBalance(self, estimatedCost: float = 0.0) -> BillingCheckResult: + """ + Check if the current user/mandate has sufficient balance. + + Args: + estimatedCost: Estimated cost of the operation (with markup applied) + + Returns: + BillingCheckResult indicating if operation is allowed + """ + return self._billingInterface.checkBalance( + self.mandateId, + self.currentUser.id, + estimatedCost + ) + + def hasBalance(self, estimatedCost: float = 0.0) -> bool: + """ + Quick check if balance is sufficient. + + Args: + estimatedCost: Estimated cost with markup + + Returns: + True if operation is allowed + """ + result = self.checkBalance(estimatedCost) + return result.allowed + + def getCurrentBalance(self) -> float: + """ + Get current balance for the user/mandate. + + Returns: + Current balance in CHF + """ + result = self.checkBalance(0.0) + return result.currentBalance or 0.0 + + # ========================================================================= + # Usage Recording + # ========================================================================= + + def recordUsage( + self, + priceCHF: float, + workflowId: str = None, + aicoreProvider: str = None, + description: str = None + ) -> Optional[Dict[str, Any]]: + """ + Record AI usage cost as a billing transaction. + + This method: + 1. Applies the pricing markup + 2. Creates a DEBIT transaction + 3. Updates the account balance + + Args: + priceCHF: Base price from AI model (before markup) + workflowId: Optional workflow ID + aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai') + description: Optional description + + Returns: + Created transaction dict or None if not recorded + """ + if priceCHF <= 0: + return None + + # Apply markup + finalPrice = self.calculatePriceWithMarkup(priceCHF) + + if finalPrice <= 0: + return None + + # Build description + if not description: + description = f"AI Usage: {aicoreProvider or 'unknown'}" + + return self._billingInterface.recordUsage( + mandateId=self.mandateId, + userId=self.currentUser.id, + priceCHF=finalPrice, + workflowId=workflowId, + featureInstanceId=self.featureInstanceId, + featureCode=self.featureCode, + aicoreProvider=aicoreProvider, + description=description + ) + + # ========================================================================= + # Provider Permission Check (via RBAC) + # ========================================================================= + + def isProviderAllowed(self, provider: str) -> bool: + """ + Check if the user has permission to use an AICore provider. + + Uses RBAC to check for resource permission: + resource.aicore.{provider} + + Args: + provider: Provider name (e.g., 'anthropic', 'openai') + + Returns: + True if provider is allowed + """ + try: + from modules.security.rbac import RbacClass + from modules.datamodels.datamodelRbac import AccessRuleContext + from modules.connectors.connectorDbPostgre import DatabaseConnector + from modules.shared.configuration import APP_CONFIG + + # Get database connectors + dbApp = DatabaseConnector( + databaseName="poweron_app", + host=APP_CONFIG.get('Database_Host', 'localhost'), + port=int(APP_CONFIG.get('Database_Port', '5432')), + user=APP_CONFIG.get('Database_User', 'admin'), + password=APP_CONFIG.get('Database_Password', 'admin') + ) + + rbac = RbacClass(dbApp, dbApp) + resourceKey = f"resource.aicore.{provider}" + + # Check if user has view permission for this resource (view = use for RESOURCE context) + permissions = rbac.getUserPermissions( + self.currentUser, + AccessRuleContext.RESOURCE, + resourceKey, + mandateId=self.mandateId + ) + + return permissions.view + except Exception as e: + logger.warning(f"Error checking provider permission: {e}") + # Default to allowed if RBAC check fails + return True + + def getallowedProviders(self) -> List[str]: + """ + Get list of AICore providers the user is allowed to use. + + Returns: + List of allowed provider names + """ + try: + from modules.aicore.aicoreModelRegistry import modelRegistry + + # Get all available providers + connectors = modelRegistry.discoverConnectors() + allProviders = [c.getConnectorType() for c in connectors] + + # Filter by RBAC permissions + return [p for p in allProviders if self.isProviderAllowed(p)] + except Exception as e: + logger.warning(f"Error getting allowed providers: {e}") + return [] + + # ========================================================================= + # Admin Operations + # ========================================================================= + + def addCredit( + self, + amount: float, + description: str = "Manual credit", + referenceType: ReferenceTypeEnum = ReferenceTypeEnum.ADMIN + ) -> Optional[Dict[str, Any]]: + """ + Add credit to the account (admin operation). + + Args: + amount: Amount to credit (positive) + description: Transaction description + referenceType: Reference type (ADMIN, PAYMENT, SYSTEM) + + Returns: + Created transaction dict or None + """ + if amount <= 0: + return None + + settings = self._getSettings() + if not settings: + logger.warning(f"No billing settings for mandate {self.mandateId}") + return None + + billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) + + # Get or create account + if billingModel == BillingModelEnum.PREPAY_USER: + account = self._billingInterface.getOrCreateUserAccount( + self.mandateId, + self.currentUser.id, + initialBalance=0.0 + ) + else: + account = self._billingInterface.getOrCreateMandateAccount( + self.mandateId, + initialBalance=0.0 + ) + + # Create credit transaction + transaction = BillingTransaction( + accountId=account["id"], + transactionType=TransactionTypeEnum.CREDIT, + amount=amount, + description=description, + referenceType=referenceType + ) + + return self._billingInterface.createTransaction(transaction) + + # ========================================================================= + # Statistics & Reporting + # ========================================================================= + + def getBalancesForUser(self) -> List[BillingBalanceResponse]: + """ + Get all billing balances for the current user. + + Returns: + List of balance responses for each mandate + """ + return self._billingInterface.getBalancesForUser(self.currentUser.id) + + def getTransactionHistory(self, limit: int = 100) -> List[Dict[str, Any]]: + """ + Get transaction history for the current mandate. + + Args: + limit: Maximum number of transactions + + Returns: + List of transactions + """ + return self._billingInterface.getTransactionsByMandate(self.mandateId, limit=limit) + + +# ============================================================================ +# Exception Classes +# ============================================================================ + +class InsufficientBalanceException(Exception): + """Raised when there's insufficient balance for an operation.""" + + def __init__(self, currentBalance: float, requiredAmount: float, message: str = None): + self.currentBalance = currentBalance + self.requiredAmount = requiredAmount + self.message = message or f"Insufficient balance. Current: {currentBalance:.2f} CHF, Required: {requiredAmount:.2f} CHF" + super().__init__(self.message) + + +class ProviderNotAllowedException(Exception): + """Raised when a user doesn't have permission to use an AI provider.""" + + def __init__(self, provider: str, message: str = None): + self.provider = provider + self.message = message or f"Provider '{provider}' is not allowed for your role" + super().__init__(self.message) diff --git a/modules/services/serviceChat/mainServiceChat.py b/modules/services/serviceChat/mainServiceChat.py index b7910720..37b232b8 100644 --- a/modules/services/serviceChat/mainServiceChat.py +++ b/modules/services/serviceChat/mainServiceChat.py @@ -674,7 +674,8 @@ class ChatService: return chatLog def storeWorkflowStat(self, workflow: Any, aiResponse: Any, process: str) -> ChatStat: - """Persist workflow-level ChatStat from AiCallResponse and append to workflow stats list.""" + """Persist workflow-level ChatStat from AiCallResponse and append to workflow stats list. + Also records the usage cost to the billing system if configured.""" try: # Create ChatStat from AiCallResponse data statData = { @@ -696,10 +697,69 @@ class ChatService: workflow.stats = [] workflow.stats.append(stat) + # Record billing transaction if mandateId is available + self._recordBillingUsage(workflow, aiResponse, process) + return stat except Exception as e: logger.error(f"Failed to store workflow stat: {e}") raise + + def _recordBillingUsage(self, workflow: Any, aiResponse: Any, process: str) -> None: + """Record AI usage to the billing system. + + This method: + 1. Gets the mandate context from services + 2. Records the usage cost with 50% markup via BillingService + + Args: + workflow: ChatWorkflow object + aiResponse: AI call response with cost information + process: Process identifier for the AI call + """ + try: + # Check if we have mandate context + mandateId = getattr(self.services, 'mandateId', None) + if not mandateId: + logger.debug("No mandate context, skipping billing recording") + return + + # Check if there's a cost to record + priceCHF = getattr(aiResponse, 'priceCHF', 0.0) + if not priceCHF or priceCHF <= 0: + return + + # Extract provider from model name (e.g., "anthropic.claude-3-sonnet" -> "anthropic") + modelName = getattr(aiResponse, 'modelName', '') or '' + aicoreProvider = modelName.split('.')[0] if '.' in modelName else 'unknown' + + # Get feature context if available + featureInstanceId = getattr(self.services, 'featureInstanceId', None) + featureCode = getattr(self.services, 'featureCode', None) + + # Import and use BillingService + from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService + + billingService = getBillingService( + self.user, + mandateId, + featureInstanceId=featureInstanceId, + featureCode=featureCode + ) + + # Record the usage (includes 50% markup) + billingService.recordUsage( + priceCHF=priceCHF, + workflowId=workflow.id, + aicoreProvider=aicoreProvider, + description=f"AI Usage: {process}" + ) + + logger.debug(f"Recorded billing usage: {priceCHF} CHF for {process} (provider: {aicoreProvider})") + + except Exception as e: + # Don't fail the main operation if billing recording fails + logger.warning(f"Failed to record billing usage (non-critical): {e}") def updateMessage(self, messageId: str, messageData: Dict[str, Any]): """Update message by delegating to the chat interface""" diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py index 9b300d78..e80efbe6 100644 --- a/modules/system/mainSystem.py +++ b/modules/system/mainSystem.py @@ -91,6 +91,29 @@ NAVIGATION_SECTIONS = [ }, ], }, + { + "id": "billing", + "title": {"en": "BILLING", "de": "BILLING", "fr": "FACTURATION"}, + "order": 35, + "items": [ + { + "id": "billing-dashboard", + "objectKey": "ui.billing.dashboard", + "label": {"en": "Balance", "de": "Guthaben", "fr": "Solde"}, + "icon": "FaWallet", + "path": "/billing", + "order": 10, + }, + { + "id": "billing-transactions", + "objectKey": "ui.billing.transactions", + "label": {"en": "Transactions", "de": "Transaktionen", "fr": "Transactions"}, + "icon": "FaListAlt", + "path": "/billing/transactions", + "order": 20, + }, + ], + }, { "id": "admin", "title": {"en": "ADMINISTRATION", "de": "ADMINISTRATION", "fr": "ADMINISTRATION"}, @@ -178,6 +201,15 @@ NAVIGATION_SECTIONS = [ "order": 50, "adminOnly": True, }, + { + "id": "admin-billing", + "objectKey": "ui.admin.billing", + "label": {"en": "Billing Administration", "de": "Billing-Verwaltung", "fr": "Administration de facturation"}, + "icon": "FaMoneyBillAlt", + "path": "/admin/billing", + "order": 60, + "adminOnly": True, + }, ], }, ] From fd923b89b8e5a79d459a92959a633d938c55d6b4 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Wed, 4 Feb 2026 22:10:23 +0100 Subject: [PATCH 03/18] billing integration into ai workflow --- modules/datamodels/datamodelChat.py | 2 + modules/services/serviceAi/mainServiceAi.py | 113 +++++++++++++++++++ modules/workflows/automation/mainWorkflow.py | 22 +++- 3 files changed, 136 insertions(+), 1 deletion(-) diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index 8ba3ced1..02f80762 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -399,6 +399,7 @@ class UserInputRequest(BaseModel): listFileId: List[str] = Field(default_factory=list, description="List of file IDs") userLanguage: str = Field(default="en", description="User's preferred language") workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue") + preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai')") registerModelLabels( @@ -408,6 +409,7 @@ registerModelLabels( "prompt": {"en": "Prompt", "fr": "Invite"}, "listFileId": {"en": "File IDs", "fr": "IDs des fichiers"}, "userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"}, + "preferredProvider": {"en": "Preferred Provider", "fr": "Fournisseur préféré"}, }, ) diff --git a/modules/services/serviceAi/mainServiceAi.py b/modules/services/serviceAi/mainServiceAi.py index a728bafc..81d83022 100644 --- a/modules/services/serviceAi/mainServiceAi.py +++ b/modules/services/serviceAi/mainServiceAi.py @@ -18,6 +18,11 @@ from modules.shared.jsonUtils import ( ) from .subJsonResponseHandling import JsonResponseHandler from modules.datamodels.datamodelAi import JsonAccumulationState +from modules.services.serviceBilling.mainServiceBilling import ( + getService as getBillingService, + InsufficientBalanceException, + ProviderNotAllowedException +) logger = logging.getLogger(__name__) @@ -85,12 +90,120 @@ class AiService: Replaces direct calls to self.aiObjects.call() to route content parts processing through serviceExtraction layer. + + Includes billing checks: + - Balance check before AI call + - Provider permission check (via RBAC) """ + # Billing check before AI call + await self._checkBillingBeforeAiCall() + if hasattr(request, 'contentParts') and request.contentParts: return await self.extractionService.processContentPartsWithAi( request, self.aiObjects, progressCallback ) return await self.aiObjects.callWithTextContext(request) + + async def _checkBillingBeforeAiCall(self) -> None: + """ + Check billing status before making an AI call. + + Verifies: + 1. User has sufficient balance (for prepay models) + 2. Provider is allowed for the user (via RBAC) + + Raises: + InsufficientBalanceException: If balance is insufficient + ProviderNotAllowedException: If provider is not allowed + """ + try: + # Get context from services + if not self.services: + logger.debug("No service center - skipping billing check") + return + + user = getattr(self.services, 'user', None) + mandateId = getattr(self.services, 'mandateId', None) + + if not user or not mandateId: + logger.debug("No user or mandate context - skipping billing check") + return + + # Get feature context + featureInstanceId = getattr(self.services, 'featureInstanceId', None) + featureCode = getattr(self.services, 'featureCode', None) + + # Get billing service + billingService = getBillingService( + user, + mandateId, + featureInstanceId, + featureCode + ) + + # Check balance (estimate typical AI call cost) + # We use a small estimate here; actual cost is recorded after the call + estimatedCost = 0.01 # ~1 cent CHF minimum + balanceCheck = billingService.checkBalance(estimatedCost) + + if not balanceCheck.allowed: + logger.warning( + f"Billing check failed for user {user.id}: " + f"Balance {balanceCheck.currentBalance:.2f} CHF, " + f"Reason: {balanceCheck.reason}" + ) + raise InsufficientBalanceException( + currentBalance=balanceCheck.currentBalance or 0.0, + requiredAmount=estimatedCost, + message=f"Ungenügendes Guthaben. Aktuell: CHF {balanceCheck.currentBalance:.2f}" + ) + + logger.debug(f"Billing check passed: Balance {balanceCheck.currentBalance:.2f} CHF") + + # Check if at least one provider is allowed (RBAC check) + rbacAllowedProviders = billingService.getallowedProviders() + if not rbacAllowedProviders: + logger.warning(f"No AI providers allowed for user {user.id} in mandate {mandateId}") + raise ProviderNotAllowedException( + provider="any", + message="Keine AI-Provider für Ihre Rolle freigegeben. Kontaktieren Sie Ihren Administrator." + ) + + # Check automation-level allowedProviders restriction + automationAllowedProviders = getattr(self.services, 'allowedProviders', None) + if automationAllowedProviders: + # Filter by both RBAC and automation-level restrictions + effectiveProviders = [p for p in automationAllowedProviders if p in rbacAllowedProviders] + if not effectiveProviders: + logger.warning(f"No providers available after automation restriction. " + f"Automation allows: {automationAllowedProviders}, " + f"RBAC allows: {rbacAllowedProviders}") + raise ProviderNotAllowedException( + provider="any", + message="Die konfigurierten AI-Provider dieser Automation sind für Ihre Rolle nicht freigegeben." + ) + logger.debug(f"Automation provider check passed: {effectiveProviders}") + + # Check if preferred provider (from UI selection) is allowed + preferredProvider = getattr(self.services, 'preferredProvider', None) + if preferredProvider: + if preferredProvider not in rbacAllowedProviders: + logger.warning(f"Preferred provider {preferredProvider} not allowed for user {user.id}") + raise ProviderNotAllowedException( + provider=preferredProvider, + message=f"Der gewählte Provider '{preferredProvider}' ist für Ihre Rolle nicht freigegeben." + ) + logger.debug(f"Preferred provider {preferredProvider} is allowed") + + logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed") + + except InsufficientBalanceException: + raise # Re-raise billing exceptions + except ProviderNotAllowedException: + raise # Re-raise provider exceptions + except Exception as e: + # Log but don't block on billing check errors + logger.warning(f"Billing check failed with error (non-blocking): {e}") async def ensureAiObjectsInitialized(self): """Ensure aiObjects is initialized and submodules are ready.""" diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py index e7c6839e..99a89df1 100644 --- a/modules/workflows/automation/mainWorkflow.py +++ b/modules/workflows/automation/mainWorkflow.py @@ -24,7 +24,7 @@ from .subAutomationUtils import parseScheduleToCron, planToPrompt, replacePlaceh logger = logging.getLogger(__name__) -async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None) -> ChatWorkflow: +async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow: """ Starts a new chat or continues an existing one, then launches processing asynchronously. @@ -34,12 +34,24 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode workflowId: Optional workflow ID to continue existing workflow workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution mandateId: Mandate ID from request context (required for proper data isolation) + featureInstanceId: Feature instance ID for context Example usage for Dynamic mode: workflow = await chatStart(currentUser, userInput, workflowMode=WorkflowModeEnum.WORKFLOW_DYNAMIC, mandateId=mandateId) """ try: services = getServices(currentUser, mandateId=mandateId) + + # Store preferred provider in services context for billing/model selection + if hasattr(userInput, 'preferredProvider') and userInput.preferredProvider: + services.preferredProvider = userInput.preferredProvider + logger.debug(f"Using preferred provider: {userInput.preferredProvider}") + + # Store feature instance ID in services context + if featureInstanceId: + services.featureInstanceId = featureInstanceId + services.featureCode = 'chatplayground' + workflowManager = WorkflowManager(services) workflow = await workflowManager.workflowStart(userInput, workflowMode, workflowId) return workflow @@ -84,6 +96,14 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: executionLog["messages"].append(f"Started execution at {executionStartTime}") + # Store allowed providers from automation in services context + if hasattr(automation, 'allowedProviders') and automation.allowedProviders: + services.allowedProviders = automation.allowedProviders + logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}") + + # Store feature context for billing + services.featureCode = 'automation' + # 2. Replace placeholders in template to generate plan template = automation.template or "" placeholders = automation.placeholders or {} From d5226a5599206ece56235b2a64e48ba136a14cb2 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Wed, 4 Feb 2026 22:34:41 +0100 Subject: [PATCH 04/18] boot running without errors --- modules/interfaces/interfaceDbBilling.py | 43 +++++--- modules/routes/routeBilling.py | 98 ++++++++----------- .../serviceBilling/mainServiceBilling.py | 13 +-- 3 files changed, 77 insertions(+), 77 deletions(-) diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index ebbe3ae5..141fc118 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -92,11 +92,11 @@ class BillingObjects: def _initializeDatabase(self): """Initialize database connection.""" self.db = DatabaseConnector( - databaseName=BILLING_DATABASE, - host=APP_CONFIG.get('Database_Host', 'localhost'), - port=int(APP_CONFIG.get('Database_Port', '5432')), - user=APP_CONFIG.get('Database_User', 'admin'), - password=APP_CONFIG.get('Database_Password', 'admin') + dbDatabase=BILLING_DATABASE, + dbHost=APP_CONFIG.get('DB_HOST', 'localhost'), + dbPort=int(APP_CONFIG.get('DB_PORT', '5432')), + dbUser=APP_CONFIG.get('DB_USER'), + dbPassword=APP_CONFIG.get('DB_PASSWORD_SECRET') ) def setUserContext(self, currentUser: User, mandateId: str = None): @@ -128,7 +128,7 @@ class BillingObjects: try: results = self.db.getRecordset( BillingSettings, - filterDict={"mandateId": mandateId} + recordFilter={"mandateId": mandateId} ) return results[0] if results else None except Exception as e: @@ -200,7 +200,7 @@ class BillingObjects: try: results = self.db.getRecordset( BillingAccount, - filterDict={"id": accountId} + recordFilter={"id": accountId} ) return results[0] if results else None except Exception as e: @@ -220,7 +220,7 @@ class BillingObjects: try: results = self.db.getRecordset( BillingAccount, - filterDict={ + recordFilter={ "mandateId": mandateId, "accountType": AccountTypeEnum.MANDATE.value } @@ -244,7 +244,7 @@ class BillingObjects: try: results = self.db.getRecordset( BillingAccount, - filterDict={ + recordFilter={ "mandateId": mandateId, "userId": userId, "accountType": AccountTypeEnum.USER.value @@ -255,6 +255,25 @@ class BillingObjects: logger.error(f"Error getting user account: {e}") return None + def getAccountsByMandate(self, mandateId: str) -> List[Dict[str, Any]]: + """ + Get all billing accounts for a mandate. + + Args: + mandateId: Mandate ID + + Returns: + List of BillingAccount dicts + """ + try: + return self.db.getRecordset( + BillingAccount, + recordFilter={"mandateId": mandateId} + ) + except Exception as e: + logger.error(f"Error getting accounts for mandate: {e}") + return [] + def createAccount(self, account: BillingAccount) -> Dict[str, Any]: """ Create a new billing account. @@ -405,7 +424,7 @@ class BillingObjects: """ try: filterDict = {"accountId": accountId} - results = self.db.getRecordset(BillingTransaction, filterDict=filterDict) + results = self.db.getRecordset(BillingTransaction, recordFilter=filterDict) # Apply date filters if provided if startDate or endDate: @@ -442,7 +461,7 @@ class BillingObjects: List of transaction dicts """ # Get all accounts for mandate - accounts = self.db.getRecordset(BillingAccount, filterDict={"mandateId": mandateId}) + accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId}) allTransactions = [] for account in accounts: @@ -616,7 +635,7 @@ class BillingObjects: "periodType": periodType.value } - results = self.db.getRecordset(UsageStatistics, filterDict=filterDict) + results = self.db.getRecordset(UsageStatistics, recordFilter=filterDict) # Filter by year filtered = [s for s in results if s.get("periodStart") and s["periodStart"].year == year] diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index c191e793..e3698509 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -136,11 +136,11 @@ async def getBalance( raise HTTPException(status_code=500, detail=str(e)) -@router.get("/balance/{mandateId}", response_model=BillingBalanceResponse) +@router.get("/balance/{targetMandateId}", response_model=BillingBalanceResponse) @limiter.limit("60/minute") async def getBalanceForMandate( request: Request, - mandateId: str = Path(..., description="Mandate ID"), + targetMandateId: str = Path(..., description="Mandate ID"), ctx: RequestContext = Depends(getRequestContext) ): """ @@ -149,7 +149,7 @@ async def getBalanceForMandate( try: billingService = getBillingService( ctx.currentUser, - mandateId, + targetMandateId, featureCode="billing" ) @@ -158,12 +158,12 @@ async def getBalanceForMandate( # Get mandate name from app interface from modules.interfaces.interfaceDbApp import getInterface as getAppInterface - appInterface = getAppInterface(ctx.currentUser, mandateId=mandateId) - mandate = appInterface.getMandate(mandateId) + appInterface = getAppInterface(ctx.currentUser, mandateId=targetMandateId) + mandate = appInterface.getMandate(targetMandateId) mandateName = mandate.get("name", "") if mandate else "" return BillingBalanceResponse( - mandateId=mandateId, + mandateId=targetMandateId, mandateName=mandateName, billingModel=checkResult.billingModel or BillingModelEnum.UNLIMITED, balance=checkResult.currentBalance or 0.0, @@ -173,7 +173,7 @@ async def getBalanceForMandate( ) except Exception as e: - logger.error(f"Error getting billing balance for mandate {mandateId}: {e}") + logger.error(f"Error getting billing balance for mandate {targetMandateId}: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -332,20 +332,20 @@ async def getAllowedProviders( # Admin Endpoints # ============================================================================= -@router.get("/admin/settings/{mandateId}", response_model=Dict[str, Any]) +@router.get("/admin/settings/{targetMandateId}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -@requireSysAdmin async def getSettingsAdmin( request: Request, - mandateId: str = Path(..., description="Mandate ID"), - ctx: RequestContext = Depends(getRequestContext) + targetMandateId: str = Path(..., description="Mandate ID"), + ctx: RequestContext = Depends(getRequestContext), + _admin = Depends(requireSysAdmin) ): """ Get billing settings for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, mandateId) - settings = billingInterface.getSettings(mandateId) + billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + settings = billingInterface.getSettings(targetMandateId) if not settings: raise HTTPException(status_code=404, detail="Billing settings not found") @@ -359,21 +359,21 @@ async def getSettingsAdmin( raise HTTPException(status_code=500, detail=str(e)) -@router.post("/admin/settings/{mandateId}", response_model=Dict[str, Any]) +@router.post("/admin/settings/{targetMandateId}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -@requireSysAdmin async def createOrUpdateSettings( request: Request, - mandateId: str = Path(..., description="Mandate ID"), + targetMandateId: str = Path(..., description="Mandate ID"), settingsUpdate: BillingSettingsUpdate = Body(...), - ctx: RequestContext = Depends(getRequestContext) + ctx: RequestContext = Depends(getRequestContext), + _admin = Depends(requireSysAdmin) ): """ Create or update billing settings for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, mandateId) - existingSettings = billingInterface.getSettings(mandateId) + billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + existingSettings = billingInterface.getSettings(targetMandateId) if existingSettings: # Update existing settings @@ -387,7 +387,7 @@ async def createOrUpdateSettings( from modules.datamodels.datamodelBilling import BillingSettings newSettings = BillingSettings( - mandateId=mandateId, + mandateId=targetMandateId, billingModel=settingsUpdate.billingModel or BillingModelEnum.UNLIMITED, defaultUserCredit=settingsUpdate.defaultUserCredit or 10.0, warningThresholdPercent=settingsUpdate.warningThresholdPercent or 10.0, @@ -406,14 +406,14 @@ async def createOrUpdateSettings( raise HTTPException(status_code=500, detail=str(e)) -@router.post("/admin/credit/{mandateId}", response_model=Dict[str, Any]) +@router.post("/admin/credit/{targetMandateId}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -@requireSysAdmin async def addCredit( request: Request, - mandateId: str = Path(..., description="Mandate ID"), + targetMandateId: str = Path(..., description="Mandate ID"), creditRequest: CreditAddRequest = Body(...), - ctx: RequestContext = Depends(getRequestContext) + ctx: RequestContext = Depends(getRequestContext), + _admin = Depends(requireSysAdmin) ): """ Add credit to a billing account (SysAdmin only). @@ -421,8 +421,8 @@ async def addCredit( """ try: # Get settings to determine billing model - billingInterface = getBillingInterface(ctx.currentUser, mandateId) - settings = billingInterface.getSettings(mandateId) + billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + settings = billingInterface.getSettings(targetMandateId) if not settings: raise HTTPException(status_code=404, detail="Billing settings not found for this mandate") @@ -436,13 +436,13 @@ async def addCredit( # Create user-level account if needed and add credit account = billingInterface.getOrCreateUserAccount( - mandateId, + targetMandateId, creditRequest.userId, initialBalance=0.0 ) elif billingModel in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]: # Create mandate-level account if needed and add credit - account = billingInterface.getOrCreateMandateAccount(mandateId, initialBalance=0.0) + account = billingInterface.getOrCreateMandateAccount(targetMandateId, initialBalance=0.0) else: raise HTTPException(status_code=400, detail=f"Cannot add credit to {billingModel.value} billing model") @@ -459,7 +459,7 @@ async def addCredit( result = billingInterface.createTransaction(transaction) - logger.info(f"Added {creditRequest.amount} CHF credit to account {account['id']} in mandate {mandateId}") + logger.info(f"Added {creditRequest.amount} CHF credit to account {account['id']} in mandate {targetMandateId}") return result @@ -470,34 +470,22 @@ async def addCredit( raise HTTPException(status_code=500, detail=str(e)) -@router.get("/admin/accounts/{mandateId}", response_model=List[AccountSummary]) +@router.get("/admin/accounts/{targetMandateId}", response_model=List[AccountSummary]) @limiter.limit("30/minute") -@requireSysAdmin async def getAccounts( request: Request, - mandateId: str = Path(..., description="Mandate ID"), - ctx: RequestContext = Depends(getRequestContext) + targetMandateId: str = Path(..., description="Mandate ID"), + ctx: RequestContext = Depends(getRequestContext), + _admin = Depends(requireSysAdmin) ): """ Get all billing accounts for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, mandateId) + billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) - # Get all accounts for this mandate - from modules.connectors.connectorDbPostgre import DatabaseConnector - from modules.shared.configuration import APP_CONFIG - from modules.datamodels.datamodelBilling import BillingAccount - - db = DatabaseConnector( - databaseName="poweron_billing", - host=APP_CONFIG.get('Database_Host', 'localhost'), - port=int(APP_CONFIG.get('Database_Port', '5432')), - user=APP_CONFIG.get('Database_User', 'admin'), - password=APP_CONFIG.get('Database_Password', 'admin') - ) - - accounts = db.getRecordset(BillingAccount, filterDict={"mandateId": mandateId}) + # Get all accounts for this mandate via interface + accounts = billingInterface.getAccountsByMandate(targetMandateId) result = [] for acc in accounts: @@ -519,21 +507,21 @@ async def getAccounts( raise HTTPException(status_code=500, detail=str(e)) -@router.get("/admin/transactions/{mandateId}", response_model=List[TransactionResponse]) +@router.get("/admin/transactions/{targetMandateId}", response_model=List[TransactionResponse]) @limiter.limit("30/minute") -@requireSysAdmin async def getTransactionsAdmin( request: Request, - mandateId: str = Path(..., description="Mandate ID"), + targetMandateId: str = Path(..., description="Mandate ID"), limit: int = Query(default=100, ge=1, le=1000), - ctx: RequestContext = Depends(getRequestContext) + ctx: RequestContext = Depends(getRequestContext), + _admin = Depends(requireSysAdmin) ): """ Get all transactions for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, mandateId) - transactions = billingInterface.getTransactionsByMandate(mandateId, limit=limit) + billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=limit) result = [] for t in transactions: @@ -553,5 +541,5 @@ async def getTransactionsAdmin( return result except Exception as e: - logger.error(f"Error getting billing transactions for mandate {mandateId}: {e}") + logger.error(f"Error getting billing transactions for mandate {targetMandateId}: {e}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/modules/services/serviceBilling/mainServiceBilling.py b/modules/services/serviceBilling/mainServiceBilling.py index 709ab61a..c7a08a1c 100644 --- a/modules/services/serviceBilling/mainServiceBilling.py +++ b/modules/services/serviceBilling/mainServiceBilling.py @@ -255,17 +255,10 @@ class BillingService: try: from modules.security.rbac import RbacClass from modules.datamodels.datamodelRbac import AccessRuleContext - from modules.connectors.connectorDbPostgre import DatabaseConnector - from modules.shared.configuration import APP_CONFIG + from modules.security.rootAccess import getRootDbAppConnector - # Get database connectors - dbApp = DatabaseConnector( - databaseName="poweron_app", - host=APP_CONFIG.get('Database_Host', 'localhost'), - port=int(APP_CONFIG.get('Database_Port', '5432')), - user=APP_CONFIG.get('Database_User', 'admin'), - password=APP_CONFIG.get('Database_Password', 'admin') - ) + # Get database connector via established pattern + dbApp = getRootDbAppConnector() rbac = RbacClass(dbApp, dbApp) resourceKey = f"resource.aicore.{provider}" From bb10a46cd5db09bfdc648c7f68314a707b48d0f7 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Fri, 6 Feb 2026 10:26:54 +0100 Subject: [PATCH 05/18] integrated privateLLM --- app.py | 6 + env_dev.env | 1 + env_int.env | 1 + env_prod.env | 1 + modules/aicore/aicorePluginPrivateLlm.py | 496 +++++++++++++++++++++++ 5 files changed, 505 insertions(+) create mode 100644 modules/aicore/aicorePluginPrivateLlm.py diff --git a/app.py b/app.py index 609d0c07..9aa05093 100644 --- a/app.py +++ b/app.py @@ -404,10 +404,16 @@ def getAllowedOrigins(): return origins +# CORS origin regex pattern for wildcard subdomain support +# Matches all subdomains of poweron.swiss and poweron-center.net +CORS_ORIGIN_REGEX = r"https://.*\.(poweron\.swiss|poweron-center\.net)" + + # CORS configuration using environment variables app.add_middleware( CORSMiddleware, allow_origins=getAllowedOrigins(), + allow_origin_regex=CORS_ORIGIN_REGEX, allow_credentials=True, allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], allow_headers=["*"], diff --git a/env_dev.env b/env_dev.env index ac5349a7..5339bbaf 100644 --- a/env_dev.env +++ b/env_dev.env @@ -40,6 +40,7 @@ Connector_AiOpenai_API_SECRET = DEV_ENC:Z0FBQUFBQnBaSnM4TWFRRmxVQmNQblVIYmc1Y0Q3 Connector_AiAnthropic_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpENmFBWG16STFQUVZxNzZZRzRLYTA4X3lRanF1VkF4cU45OExNMzlsQmdISGFxTUxud1dXODBKcFhMVG9KNjdWVnlTTFFROVc3NDlsdlNHLUJXeG41NDBHaXhHR0VHVWl5UW9RNkVWbmlhakRKVW5pM0R4VHk0LUw0TV9LdkljNHdBLXJua21NQkl2b3l4UkVkMGN1YjBrMmJEeWtMay1jbmxrYWJNbUV0aktCXzU1djR2d2RSQXZORTNwcG92ZUVvVGMtQzQzTTVncEZTRGRtZUFIZWQ0dz09 Connector_AiPerplexity_API_SECRET = DEV_ENC:Z0FBQUFBQm82Mzk2Q1MwZ0dNcUVBcUtuRDJIcTZkMXVvYnpjM3JEMzJiT1NKSHljX282ZDIyZTJYc09VSTdVNXAtOWU2UXp5S193NTk5dHJsWlFjRjhWektFOG1DVGY4ZUhHTXMzS0RPN1lNcF9nSlVWbW5BZ1hkZDVTejl6bVZNRFVvX29xamJidWRFMmtjQmkyRUQ2RUh6UTN1aWNPSUJBPT0= Connector_AiTavily_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEQTdnUHMwd2pIaXNtMmtCTFREd0pyQXRKb1F5eGtHSnkyOGZiUnlBOFc0b3Vzcndrc3ViRm1nMDJIOEZKYWxqdWNkZGh5N0Z4R0JlQmxXSG5pVnJUR2VYckZhMWNMZ1FNeXJ3enJLVlpiblhOZTNleUg3ZzZyUzRZanFSeDlVMkI= +Connector_AiPrivateLlm_API_SECRET = jL4vyNfh_tv4rxoRaHKW88sVWNHbj32GsxuKE2A8bf0 # Microsoft Service Configuration Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c diff --git a/env_int.env b/env_int.env index 05313802..5534cbdf 100644 --- a/env_int.env +++ b/env_int.env @@ -40,6 +40,7 @@ Connector_AiOpenai_API_SECRET = INT_ENC:Z0FBQUFBQnBaSnM4MENkQ2xJVmE5WFZKUkh2SHJF Connector_AiAnthropic_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjT1ZlRWVJdVZMT3ljSFJDcFdxRFBRVkZhS204NnN5RDBlQ0tpenhTM0FFVktuWW9mWHNwRWx2dHB0eDBSZ0JFQnZKWlp6c01pVGREWHd1eGpERnU0Q2xhaks1clQ1ZXVsdnd2ZzhpNXNQS1BhY3FjSkdkVEhHalNaRGR4emhpakZncnpDQUVxOHVXQzVUWmtQc0FsYmFwTF9TSG5FOUFtWk5Ick1NcHFvY2s1T1c2WXlRUFFJZnh6TWhuaVpMYmppcDR0QUx0a0R6RXlwbGRYb1R4dzJkUT09 Connector_AiPerplexity_API_SECRET = INT_ENC:Z0FBQUFBQm82Mzk2UWZJdUFhSW8yc3RKc0tKRXphd0xWMkZOVlFpSGZ4SGhFWnk0cTF5VjlKQVZjdS1QSWdkS0pUSWw4OFU5MjUxdTVQel9aeWVIZTZ5TXRuVmFkZG0zWEdTOGdHMHpsTzI0TGlWYURKU1Q0VVpKTlhxUk5FTmN6SUJScDZ3ZldIaUJZcWpaQVRiSEpyQm9tRTNDWk9KTnZBPT0= Connector_AiTavily_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRkdkJMTDY0akhXNzZDWHVYSEt1cDZoOWEzSktneHZEV2JndTNmWlNSMV9KbFNIZmQzeVlrNE5qUEIwcUlBSGM1a0hOZ3J6djIyOVhnZzI3M1dIUkdicl9FVXF3RGktMmlEYmhnaHJfWTdGUkktSXVUSGdQMC1vSEV6VE8zR2F1SVk= +Connector_AiPrivateLlm_API_SECRET = jL4vyNfh_tv4rxoRaHKW88sVWNHbj32GsxuKE2A8bf0 # Microsoft Service Configuration Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c diff --git a/env_prod.env b/env_prod.env index 57a4e83c..a7b4512c 100644 --- a/env_prod.env +++ b/env_prod.env @@ -40,6 +40,7 @@ Connector_AiOpenai_API_SECRET = PROD_ENC:Z0FBQUFBQnBaSnM4TWJOVm4xVkx6azRlNDdxN3U Connector_AiAnthropic_API_SECRET = PROD_ENC:Z0FBQUFBQnBDM1Z3TnhYdlhSLW5RbXJyMHFXX0V0bHhuTDlTaFJsRDl2dTdIUTFtVFAwTE8tY3hLbzNSMnVTLXd3RUZualN3MGNzc1kwOTIxVUN2WW1rYi1TendFRVVBSVNqRFVjckEzNExyTGNaUkJLMmozazUwemI1cnhrcEtZVXJrWkdaVFFramp3MWZ6RmY2aGlRMXVEYjM2M3ZlbmxMdnNCRDM1QWR0Wmd6MWVnS1I1c01nV3hRLXg3d2NTZXVfTi1Wdm16UnRyNGsyRTZ0bG9TQ1g1OFB5Z002bmQ3QT09 Connector_AiPerplexity_API_SECRET = PROD_ENC:Z0FBQUFBQm82Mzk2Q1FGRkJEUkI4LXlQbHYzT2RkdVJEcmM4WGdZTWpJTEhoeUF1NW5LUVpJdDBYN3k1WFN4a2FQSWJSQmd0U0xJbzZDTmFFN05FcXl0Z3V1OEpsZjYydV94TXVjVjVXRTRYSWdLMkd5XzZIbFV6emRCZHpuOUpQeThadE5xcDNDVGV1RHJrUEN0c1BBYXctZFNWcFRuVXhRPT0= Connector_AiTavily_API_SECRET = PROD_ENC:Z0FBQUFBQnBDM1Z3NmItcDh6V0JpcE5Jc0NlUWZqcmllRHB5eDlNZmVnUlNVenhNTm5xWExzbjJqdE1GZ0hTSUYtb2dvdWNhTnlQNmVWQ2NGVDgwZ0MwMWZBMlNKWEhzdlF3TlZzTXhCZWM4Z1Uwb18tSTRoU1JBVTVkSkJHOTJwX291b3dPaVphVFg= +Connector_AiPrivateLlm_API_SECRET = jL4vyNfh_tv4rxoRaHKW88sVWNHbj32GsxuKE2A8bf0 # Microsoft Service Configuration Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c diff --git a/modules/aicore/aicorePluginPrivateLlm.py b/modules/aicore/aicorePluginPrivateLlm.py new file mode 100644 index 00000000..3b9754d2 --- /dev/null +++ b/modules/aicore/aicorePluginPrivateLlm.py @@ -0,0 +1,496 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +AI Connector for PowerOn Private-LLM Service. + +Connects to the private-llm service running on-premise with Ollama backend. +Provides OCR and Vision capabilities via local AI models. + +Models: +- poweron-ocr-general: Text extraction and OCR (deepseek backend) +- poweron-vision-general: General vision tasks (qwen2.5vl backend) +- poweron-vision-deep: Deep vision analysis (granite3.2 backend) + +Pricing (CHF per call): +- Text models: CHF 0.010 +- Vision models: CHF 0.100 +""" + +import logging +import httpx +import time +from typing import List, Optional, Dict, Any +from fastapi import HTTPException +from modules.shared.configuration import APP_CONFIG +from .aicoreBase import BaseConnectorAi +from modules.datamodels.datamodelAi import ( + AiModel, + PriorityEnum, + ProcessingModeEnum, + OperationTypeEnum, + AiModelCall, + AiModelResponse, + createOperationTypeRatings +) + +# Configure logger +logger = logging.getLogger(__name__) + +# Pricing constants (CHF) +PRICE_TEXT_PER_CALL = 0.01 # CHF 0.010 per text model call +PRICE_VISION_PER_CALL = 0.10 # CHF 0.100 per vision model call + + +# Private-LLM Service URL (fix, nicht via env konfigurierbar) +PRIVATE_LLM_BASE_URL = "https://llm.poweron.swiss:8000" + + +def _loadConfigData(): + """Load configuration data for Private-LLM connector.""" + return { + "apiKey": APP_CONFIG.get("Connector_AiPrivateLlm_API_SECRET"), + "baseUrl": PRIVATE_LLM_BASE_URL, + } + + +class AiPrivateLlm(BaseConnectorAi): + """Connector for communication with the PowerOn Private-LLM Service.""" + + def __init__(self): + super().__init__() + # Load configuration + self.config = _loadConfigData() + self.apiKey = self.config["apiKey"] + self.baseUrl = self.config["baseUrl"] + + # HTTP client for API calls + # Timeout set to 3600 seconds (60 minutes) for large model processing + headers = {"Content-Type": "application/json"} + if self.apiKey: + headers["X-API-Key"] = self.apiKey + + self.httpClient = httpx.AsyncClient( + timeout=3600.0, + headers=headers + ) + + # Cache for service availability check + self._serviceAvailable: Optional[bool] = None + self._availableOllamaModels: Optional[List[str]] = None + self._lastAvailabilityCheck: float = 0 + self._availabilityCacheTtl: float = 60.0 # 60 seconds cache + + logger.info(f"Private-LLM Connector initialized (URL: {self.baseUrl})") + + def getConnectorType(self) -> str: + """Get the connector type identifier.""" + return "privatellm" + + def _checkServiceAvailability(self) -> Dict[str, Any]: + """ + Check if the Private-LLM service is available and which Ollama models are installed. + Uses caching to avoid excessive health checks. + + Returns: + Dict with 'serviceAvailable', 'ollamaConnected', 'availableModels' + """ + import asyncio + + currentTime = time.time() + + # Return cached result if still valid + if (self._serviceAvailable is not None and + currentTime - self._lastAvailabilityCheck < self._availabilityCacheTtl): + return { + "serviceAvailable": self._serviceAvailable, + "ollamaConnected": self._serviceAvailable, + "availableModels": self._availableOllamaModels or [] + } + + # Perform availability check + try: + # Use synchronous client for blocking check during initialization + with httpx.Client(timeout=5.0) as client: + headers = {"Content-Type": "application/json"} + if self.apiKey: + headers["X-API-Key"] = self.apiKey + + # Check health endpoint + healthResponse = client.get( + f"{self.baseUrl}/api/health", + headers=headers + ) + + if healthResponse.status_code != 200: + logger.warning(f"Private-LLM service not available: HTTP {healthResponse.status_code}") + self._serviceAvailable = False + self._availableOllamaModels = [] + self._lastAvailabilityCheck = currentTime + return {"serviceAvailable": False, "ollamaConnected": False, "availableModels": []} + + healthData = healthResponse.json() + ollamaConnected = healthData.get("ollamaConnected", False) + + if not ollamaConnected: + logger.warning("Private-LLM service available but Ollama not connected") + self._serviceAvailable = True + self._availableOllamaModels = [] + self._lastAvailabilityCheck = currentTime + return {"serviceAvailable": True, "ollamaConnected": False, "availableModels": []} + + # Check Ollama status for available models + statusResponse = client.get( + f"{self.baseUrl}/api/ollama/status", + headers=headers + ) + + if statusResponse.status_code == 200: + statusData = statusResponse.json() + self._availableOllamaModels = statusData.get("models", []) + else: + self._availableOllamaModels = [] + + self._serviceAvailable = True + self._lastAvailabilityCheck = currentTime + + logger.info(f"Private-LLM availability check: service=OK, ollama=OK, models={len(self._availableOllamaModels)}") + + return { + "serviceAvailable": True, + "ollamaConnected": True, + "availableModels": self._availableOllamaModels + } + + except httpx.ConnectError: + logger.warning(f"Private-LLM service not reachable at {self.baseUrl}") + self._serviceAvailable = False + self._availableOllamaModels = [] + self._lastAvailabilityCheck = currentTime + return {"serviceAvailable": False, "ollamaConnected": False, "availableModels": []} + except Exception as e: + logger.warning(f"Error checking Private-LLM availability: {e}") + self._serviceAvailable = False + self._availableOllamaModels = [] + self._lastAvailabilityCheck = currentTime + return {"serviceAvailable": False, "ollamaConnected": False, "availableModels": []} + + def _isModelAvailableInOllama(self, ollamaModelName: str, availableModels: List[str]) -> bool: + """ + Check if a model is available in Ollama. + Handles model name variations (with/without tags). + """ + if not availableModels: + return False + + # Direct match + if ollamaModelName in availableModels: + return True + + # Check without tag (e.g., "qwen2.5vl:72b" -> "qwen2.5vl") + baseModelName = ollamaModelName.split(":")[0] + for availModel in availableModels: + availBase = availModel.split(":")[0] + if baseModelName == availBase: + return True + + return False + + def getModels(self) -> List[AiModel]: + """ + Get all available Private-LLM models. + + Checks service availability and returns only models that are actually available + in the connected Ollama instance. Returns empty list if service is not reachable. + """ + # Check service availability + availability = self._checkServiceAvailability() + + if not availability["serviceAvailable"]: + logger.warning("Private-LLM service not available - no models returned") + return [] + + if not availability["ollamaConnected"]: + logger.warning("Private-LLM service available but Ollama not connected - no models returned") + return [] + + availableOllamaModels = availability.get("availableModels", []) + + # Define all models with their Ollama backend names + # Actual model specs (for 32GB RAM server): + # - deepseek-ocr: 3.34B params, 8K context, ~6.7GB RAM + # - qwen2.5vl:7b: 8.29B params, 125K context, ~6GB RAM + # - granite3.2-vision: 2B params, 16K context, ~2.4GB RAM + modelDefinitions = [ + # OCR Text Model (deepseek-ocr: 3.34B, 8K context) + { + "model": AiModel( + name="poweron-ocr-general", + displayName="PowerOn OCR General", + connectorType="privatellm", + apiUrl=f"{self.baseUrl}/api/analyze", + temperature=0.1, + maxTokens=4096, + contextLength=8192, # deepseek-ocr actual context: 8K + costPer1kTokensInput=0.0, # Flat rate pricing + costPer1kTokensOutput=0.0, # Flat rate pricing + speedRating=8, # Fast due to smaller model + qualityRating=8, + functionCall=self.callAiText, + priority=PriorityEnum.COST, + processingMode=ProcessingModeEnum.BASIC, + operationTypes=createOperationTypeRatings( + (OperationTypeEnum.DATA_EXTRACT, 9), + (OperationTypeEnum.DATA_ANALYSE, 7), + ), + version="deepseek-ocr", + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_TEXT_PER_CALL + ), + "ollamaModel": "deepseek-ocr" + }, + # Vision General Model (qwen2.5vl:7b: 8.29B, 125K context) + { + "model": AiModel( + name="poweron-vision-general", + displayName="PowerOn Vision General", + connectorType="privatellm", + apiUrl=f"{self.baseUrl}/api/analyze", + temperature=0.2, + maxTokens=8192, + contextLength=125000, # qwen2.5vl:7b actual context: 125K + costPer1kTokensInput=0.0, # Flat rate pricing + costPer1kTokensOutput=0.0, # Flat rate pricing + speedRating=7, + qualityRating=9, + functionCall=self.callAiVision, + priority=PriorityEnum.BALANCED, + processingMode=ProcessingModeEnum.ADVANCED, + operationTypes=createOperationTypeRatings( + (OperationTypeEnum.IMAGE_ANALYSE, 9), + (OperationTypeEnum.DATA_EXTRACT, 8), + ), + version="qwen2.5vl:7b", + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL + ), + "ollamaModel": "qwen2.5vl:7b" + }, + # Vision Deep Model (granite3.2-vision: 2B, 16K context) + { + "model": AiModel( + name="poweron-vision-deep", + displayName="PowerOn Vision Deep", + connectorType="privatellm", + apiUrl=f"{self.baseUrl}/api/analyze", + temperature=0.1, + maxTokens=4096, + contextLength=16000, # granite3.2-vision actual context: 16K + costPer1kTokensInput=0.0, # Flat rate pricing + costPer1kTokensOutput=0.0, # Flat rate pricing + speedRating=9, # Fast due to small 2B model + qualityRating=8, # Good for document understanding + functionCall=self.callAiVision, + priority=PriorityEnum.QUALITY, + processingMode=ProcessingModeEnum.DETAILED, + operationTypes=createOperationTypeRatings( + (OperationTypeEnum.IMAGE_ANALYSE, 9), + (OperationTypeEnum.DATA_EXTRACT, 9), + (OperationTypeEnum.DATA_ANALYSE, 8), + ), + version="granite3.2-vision", + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL + ), + "ollamaModel": "granite3.2-vision" + }, + ] + + # Filter models by Ollama availability + availableModels = [] + unavailableModels = [] + + for modelDef in modelDefinitions: + ollamaModelName = modelDef["ollamaModel"] + if self._isModelAvailableInOllama(ollamaModelName, availableOllamaModels): + availableModels.append(modelDef["model"]) + else: + unavailableModels.append(modelDef["model"].name) + + if unavailableModels: + logger.warning( + f"Private-LLM: {len(unavailableModels)} models not available in Ollama: {', '.join(unavailableModels)}. " + f"Install with: ollama pull " + ) + + if availableModels: + logger.info(f"Private-LLM: {len(availableModels)} models available") + else: + logger.warning("Private-LLM: No models available. Check Ollama installation.") + + return availableModels + + async def callAiText(self, modelCall: AiModelCall) -> AiModelResponse: + """ + Call the Private-LLM API for text-based analysis. + + Args: + modelCall: AiModelCall with messages + + Returns: + AiModelResponse with content and metadata + """ + try: + messages = modelCall.messages + model = modelCall.model + + # Extract prompt from messages + prompt = "" + for msg in messages: + content = msg.get("content", "") + if isinstance(content, str): + prompt += content + "\n" + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + prompt += part.get("text", "") + "\n" + + payload = { + "modelName": model.name, + "prompt": prompt.strip(), + "imageBase64": None + } + + logger.debug(f"Calling Private-LLM text API with model {model.name}") + + response = await self.httpClient.post( + model.apiUrl, + json=payload + ) + + if response.status_code != 200: + errorMessage = f"Private-LLM API error: {response.status_code} - {response.text}" + logger.error(errorMessage) + raise HTTPException(status_code=500, detail=errorMessage) + + responseJson = response.json() + + if not responseJson.get("success", False): + errorMsg = responseJson.get("error", "Unknown error") + logger.error(f"Private-LLM returned error: {errorMsg}") + return AiModelResponse( + content="", + success=False, + error=errorMsg + ) + + # Extract content from response + data = responseJson.get("data", {}) + rawResponse = responseJson.get("rawResponse", "") + + # Prefer rawResponse for full content, fall back to data + content = rawResponse if rawResponse else str(data.get("response", data)) + + return AiModelResponse( + content=content, + success=True, + modelId=model.name, + metadata={"data": data} + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error calling Private-LLM text API: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error calling Private-LLM API: {str(e)}") + + async def callAiVision(self, modelCall: AiModelCall) -> AiModelResponse: + """ + Call the Private-LLM API for vision-based analysis. + + Args: + modelCall: AiModelCall with messages containing image data + + Returns: + AiModelResponse with analysis content + """ + try: + messages = modelCall.messages + model = modelCall.model + + # Extract prompt and image from messages + prompt = "" + imageBase64 = None + + for msg in messages: + content = msg.get("content", "") + + if isinstance(content, str): + prompt += content + "\n" + elif isinstance(content, list): + for part in content: + if isinstance(part, dict): + if part.get("type") == "text": + prompt += part.get("text", "") + "\n" + elif part.get("type") == "image_url": + imageUrl = part.get("image_url", {}).get("url", "") + # Extract base64 from data URL + if imageUrl.startswith("data:"): + # Format: data:image/png;base64, + parts = imageUrl.split(",", 1) + if len(parts) == 2: + imageBase64 = parts[1] + else: + imageBase64 = imageUrl + + if not imageBase64: + logger.warning("No image provided for vision model call") + + payload = { + "modelName": model.name, + "prompt": prompt.strip(), + "imageBase64": imageBase64 + } + + logger.debug(f"Calling Private-LLM vision API with model {model.name}") + + response = await self.httpClient.post( + model.apiUrl, + json=payload + ) + + if response.status_code != 200: + errorMessage = f"Private-LLM API error: {response.status_code} - {response.text}" + logger.error(errorMessage) + raise HTTPException(status_code=500, detail=errorMessage) + + responseJson = response.json() + + if not responseJson.get("success", False): + errorMsg = responseJson.get("error", "Unknown error") + logger.error(f"Private-LLM returned error: {errorMsg}") + return AiModelResponse( + content="", + success=False, + error=errorMsg + ) + + # Extract content from response + data = responseJson.get("data", {}) + rawResponse = responseJson.get("rawResponse", "") + + # Prefer rawResponse for full content + content = rawResponse if rawResponse else str(data.get("response", data)) + + return AiModelResponse( + content=content, + success=True, + modelId=model.name, + metadata={"data": data} + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error calling Private-LLM vision API: {str(e)}", exc_info=True) + return AiModelResponse( + content="", + success=False, + error=f"Error during vision analysis: {str(e)}" + ) From 8dfb7caf922c6bc18f8a720cd570b6d31849ea56 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Fri, 6 Feb 2026 13:34:50 +0100 Subject: [PATCH 06/18] neues text model --- modules/aicore/aicorePluginPrivateLlm.py | 26 ++++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/modules/aicore/aicorePluginPrivateLlm.py b/modules/aicore/aicorePluginPrivateLlm.py index 3b9754d2..84a5a6b4 100644 --- a/modules/aicore/aicorePluginPrivateLlm.py +++ b/modules/aicore/aicorePluginPrivateLlm.py @@ -217,35 +217,35 @@ class AiPrivateLlm(BaseConnectorAi): # Define all models with their Ollama backend names # Actual model specs (for 32GB RAM server): - # - deepseek-ocr: 3.34B params, 8K context, ~6.7GB RAM - # - qwen2.5vl:7b: 8.29B params, 125K context, ~6GB RAM - # - granite3.2-vision: 2B params, 16K context, ~2.4GB RAM + # - qwen2.5:7b: 7.6B params, 128K context, ~4.7GB RAM (Text) + # - qwen2.5vl:7b: 8.29B params, 125K context, ~6GB RAM (Vision) + # - granite3.2-vision: 2B params, 16K context, ~2.4GB RAM (Vision) modelDefinitions = [ - # OCR Text Model (deepseek-ocr: 3.34B, 8K context) + # Text Model (qwen2.5:7b: 7.6B, 128K context) { "model": AiModel( - name="poweron-ocr-general", - displayName="PowerOn OCR General", + name="poweron-text-general", + displayName="PowerOn Text General", connectorType="privatellm", apiUrl=f"{self.baseUrl}/api/analyze", temperature=0.1, - maxTokens=4096, - contextLength=8192, # deepseek-ocr actual context: 8K + maxTokens=8192, + contextLength=128000, # qwen2.5:7b actual context: 128K costPer1kTokensInput=0.0, # Flat rate pricing costPer1kTokensOutput=0.0, # Flat rate pricing - speedRating=8, # Fast due to smaller model - qualityRating=8, + speedRating=8, # Fast and efficient + qualityRating=9, # High quality text model functionCall=self.callAiText, priority=PriorityEnum.COST, processingMode=ProcessingModeEnum.BASIC, operationTypes=createOperationTypeRatings( (OperationTypeEnum.DATA_EXTRACT, 9), - (OperationTypeEnum.DATA_ANALYSE, 7), + (OperationTypeEnum.DATA_ANALYSE, 9), ), - version="deepseek-ocr", + version="qwen2.5:7b", calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_TEXT_PER_CALL ), - "ollamaModel": "deepseek-ocr" + "ollamaModel": "qwen2.5:7b" }, # Vision General Model (qwen2.5vl:7b: 8.29B, 125K context) { From a054d12d542a07f7a7cbad24ee8fe8e2c07cb90b Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Fri, 6 Feb 2026 16:18:37 +0100 Subject: [PATCH 07/18] billing fixes --- app.py | 33 ++++++ modules/aicore/aicorePluginPrivateLlm.py | 28 ++--- modules/datamodels/datamodelChat.py | 5 +- .../routeFeatureChatplayground.py | 2 +- modules/interfaces/interfaceBootstrap.py | 49 +++++--- modules/interfaces/interfaceDbApp.py | 35 ++++++ modules/interfaces/interfaceDbBilling.py | 110 +++++++++++++++++- modules/routes/routeBilling.py | 88 ++++++++++++-- modules/services/serviceAi/mainServiceAi.py | 21 ++-- modules/workflows/automation/mainWorkflow.py | 18 ++- modules/workflows/workflowManager.py | 1 + 11 files changed, 332 insertions(+), 58 deletions(-) diff --git a/app.py b/app.py index 9aa05093..0acdcfce 100644 --- a/app.py +++ b/app.py @@ -312,6 +312,39 @@ async def lifespan(app: FastAPI): # Register audit log cleanup scheduler from modules.shared.auditLogger import registerAuditLogCleanupScheduler registerAuditLogCleanupScheduler() + + # Ensure billing settings and accounts exist + try: + from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface + from modules.datamodels.datamodelBilling import BillingSettings, BillingModelEnum + + billingInterface = getBillingRootInterface() + + # Ensure root mandate has billing settings + rootMandate = rootInterface.getRootMandate() + if rootMandate: + rootMandateId = rootMandate.get("id") if isinstance(rootMandate, dict) else getattr(rootMandate, "id", None) + if rootMandateId: + existingSettings = billingInterface.getSettings(rootMandateId) + if not existingSettings: + settings = BillingSettings( + mandateId=rootMandateId, + billingModel=BillingModelEnum.PREPAY_USER, + defaultUserCredit=10.0, + warningThresholdPercent=10.0, + blockOnZeroBalance=True, + notifyOnWarning=True + ) + billingInterface.createSettings(settings) + logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit") + + # Efficient bulk check: Ensure all users have billing accounts (3 queries total) + accountsCreated = billingInterface.ensureAllUserAccountsExist() + if accountsCreated > 0: + logger.info(f"Billing startup: Created {accountsCreated} missing user accounts") + + except Exception as e: + logger.warning(f"Failed to ensure billing settings/accounts (non-critical): {e}") yield diff --git a/modules/aicore/aicorePluginPrivateLlm.py b/modules/aicore/aicorePluginPrivateLlm.py index 84a5a6b4..dfcb7eaf 100644 --- a/modules/aicore/aicorePluginPrivateLlm.py +++ b/modules/aicore/aicorePluginPrivateLlm.py @@ -216,12 +216,14 @@ class AiPrivateLlm(BaseConnectorAi): availableOllamaModels = availability.get("availableModels", []) # Define all models with their Ollama backend names - # Actual model specs (for 32GB RAM server): - # - qwen2.5:7b: 7.6B params, 128K context, ~4.7GB RAM (Text) - # - qwen2.5vl:7b: 8.29B params, 125K context, ~6GB RAM (Vision) - # - granite3.2-vision: 2B params, 16K context, ~2.4GB RAM (Vision) + # Actual model specs (for 31GB RAM + 22GB GPU server): + # Context sizes reduced to fit in available RAM + # - qwen2.5:7b: 7.6B params, ~4.7GB RAM (Text) - 8K context + # - qwen2.5vl:7b: 8.29B params, ~6GB RAM (Vision) - 4K context + # - granite3.2-vision: 2B params, ~2.4GB RAM (Vision) - 4K context + # - deepseek-ocr: ~6.7GB RAM (OCR) - 4K context modelDefinitions = [ - # Text Model (qwen2.5:7b: 7.6B, 128K context) + # Text Model (qwen2.5:7b: 7.6B) { "model": AiModel( name="poweron-text-general", @@ -229,8 +231,8 @@ class AiPrivateLlm(BaseConnectorAi): connectorType="privatellm", apiUrl=f"{self.baseUrl}/api/analyze", temperature=0.1, - maxTokens=8192, - contextLength=128000, # qwen2.5:7b actual context: 128K + maxTokens=4096, + contextLength=8192, # Reduced for RAM constraints costPer1kTokensInput=0.0, # Flat rate pricing costPer1kTokensOutput=0.0, # Flat rate pricing speedRating=8, # Fast and efficient @@ -247,7 +249,7 @@ class AiPrivateLlm(BaseConnectorAi): ), "ollamaModel": "qwen2.5:7b" }, - # Vision General Model (qwen2.5vl:7b: 8.29B, 125K context) + # Vision General Model (qwen2.5vl:7b: 8.29B) { "model": AiModel( name="poweron-vision-general", @@ -255,8 +257,8 @@ class AiPrivateLlm(BaseConnectorAi): connectorType="privatellm", apiUrl=f"{self.baseUrl}/api/analyze", temperature=0.2, - maxTokens=8192, - contextLength=125000, # qwen2.5vl:7b actual context: 125K + maxTokens=2048, + contextLength=4096, # Reduced for RAM constraints (vision needs more) costPer1kTokensInput=0.0, # Flat rate pricing costPer1kTokensOutput=0.0, # Flat rate pricing speedRating=7, @@ -273,7 +275,7 @@ class AiPrivateLlm(BaseConnectorAi): ), "ollamaModel": "qwen2.5vl:7b" }, - # Vision Deep Model (granite3.2-vision: 2B, 16K context) + # Vision Deep Model (granite3.2-vision: 2B) { "model": AiModel( name="poweron-vision-deep", @@ -281,8 +283,8 @@ class AiPrivateLlm(BaseConnectorAi): connectorType="privatellm", apiUrl=f"{self.baseUrl}/api/analyze", temperature=0.1, - maxTokens=4096, - contextLength=16000, # granite3.2-vision actual context: 16K + maxTokens=2048, + contextLength=4096, # Reduced for RAM constraints costPer1kTokensInput=0.0, # Flat rate pricing costPer1kTokensOutput=0.0, # Flat rate pricing speedRating=9, # Fast due to small 2B model diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index 02f80762..b1e73ae0 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -301,6 +301,7 @@ registerModelLabels( class ChatWorkflow(BaseModel): """Chat workflow container. User-owned, no mandate context.""" id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) + featureInstanceId: Optional[str] = Field(None, description="Feature instance ID for multi-tenancy isolation", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) status: str = Field(default="running", description="Current status of the workflow", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [ {"value": "running", "label": {"en": "Running", "fr": "En cours"}}, {"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}}, @@ -374,6 +375,7 @@ registerModelLabels( {"en": "Chat Workflow", "fr": "Flux de travail de chat"}, { "id": {"en": "ID", "fr": "ID"}, + "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"}, "status": {"en": "Status", "fr": "Statut"}, "name": {"en": "Name", "fr": "Nom"}, "currentRound": {"en": "Current Round", "fr": "Tour actuel"}, @@ -399,7 +401,8 @@ class UserInputRequest(BaseModel): listFileId: List[str] = Field(default_factory=list, description="List of file IDs") userLanguage: str = Field(default="en", description="User's preferred language") workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue") - preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai')") + preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai') - deprecated, use preferredProviders") + preferredProviders: Optional[List[str]] = Field(None, description="List of preferred AI providers (multiselect)") registerModelLabels( diff --git a/modules/features/chatplayground/routeFeatureChatplayground.py b/modules/features/chatplayground/routeFeatureChatplayground.py index 6a76e70e..cedd05d6 100644 --- a/modules/features/chatplayground/routeFeatureChatplayground.py +++ b/modules/features/chatplayground/routeFeatureChatplayground.py @@ -131,7 +131,7 @@ async def stop_workflow( # Validate access and get mandate ID mandateId = await _validateInstanceAccess(instanceId, context) - # Stop workflow + # Stop workflow (pass featureInstanceId for proper RBAC filtering) workflow = await chatStop( context.user, workflowId, diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py index d6f0f063..99cca1a2 100644 --- a/modules/interfaces/interfaceBootstrap.py +++ b/modules/interfaces/interfaceBootstrap.py @@ -1273,34 +1273,57 @@ def initRootMandateBilling(mandateId: str) -> None: """ Initialize billing settings for root mandate. Root mandate uses PREPAY_USER model with 10 CHF initial credit per user. + Also creates billing accounts for all users of the mandate. Args: mandateId: Root mandate ID """ try: from modules.interfaces.interfaceDbBilling import _getRootInterface + from modules.interfaces.interfaceDbApp import getRootInterface as getAppRootInterface from modules.datamodels.datamodelBilling import BillingSettings, BillingModelEnum billingInterface = _getRootInterface() + appInterface = getAppRootInterface() # Check if settings already exist existingSettings = billingInterface.getSettings(mandateId) if existingSettings: logger.info("Billing settings for root mandate already exist") - return + else: + # Create billing settings for root mandate + settings = BillingSettings( + mandateId=mandateId, + billingModel=BillingModelEnum.PREPAY_USER, + defaultUserCredit=10.0, # 10 CHF initial credit per user + warningThresholdPercent=10.0, + blockOnZeroBalance=True, + notifyOnWarning=True + ) + + billingInterface.createSettings(settings) + logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit") + existingSettings = billingInterface.getSettings(mandateId) - # Create billing settings for root mandate - settings = BillingSettings( - mandateId=mandateId, - billingModel=BillingModelEnum.PREPAY_USER, - defaultUserCredit=10.0, # 10 CHF initial credit per user - warningThresholdPercent=10.0, - blockOnZeroBalance=True, - notifyOnWarning=True - ) - - billingInterface.createSettings(settings) - logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit") + # Create billing accounts for all users of the mandate + if existingSettings: + billingModel = existingSettings.get("billingModel", "UNLIMITED") + if billingModel == BillingModelEnum.PREPAY_USER.value: + defaultCredit = existingSettings.get("defaultUserCredit", 10.0) + userMandates = appInterface.getUserMandatesByMandate(mandateId) + accountsCreated = 0 + + for um in userMandates: + userId = um.get("userId") if isinstance(um, dict) else getattr(um, "userId", None) + if userId: + existingAccount = billingInterface.getUserAccount(mandateId, userId) + if not existingAccount: + billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit) + accountsCreated += 1 + logger.debug(f"Created billing account for user {userId}") + + if accountsCreated > 0: + logger.info(f"Created {accountsCreated} billing accounts for root mandate users with {defaultCredit} CHF each") except Exception as e: # Don't fail bootstrap if billing init fails diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py index 2a872bce..1c082d33 100644 --- a/modules/interfaces/interfaceDbApp.py +++ b/modules/interfaces/interfaceDbApp.py @@ -1608,6 +1608,7 @@ class AppObjects: def createUserMandate(self, userId: str, mandateId: str, roleIds: List[str] = None) -> UserMandate: """ Create a UserMandate record (add user to mandate). + Also creates a billing account for the user if billing is configured for PREPAY_USER. Args: userId: User ID @@ -1641,11 +1642,45 @@ class AppObjects: ) self.db.recordCreate(UserMandateRole, userMandateRole.model_dump()) + # Create billing account for user if billing is configured + self._ensureUserBillingAccount(userId, mandateId) + cleanedRecord = {k: v for k, v in createdRecord.items() if not k.startswith("_")} return UserMandate(**cleanedRecord) except Exception as e: logger.error(f"Error creating UserMandate: {e}") raise ValueError(f"Failed to create UserMandate: {e}") + + def _ensureUserBillingAccount(self, userId: str, mandateId: str) -> None: + """ + Ensure a user has a billing account for the mandate if billing is configured. + Creates account with default credit from settings if billingModel is PREPAY_USER. + + Args: + userId: User ID + mandateId: Mandate ID + """ + try: + from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface + from modules.datamodels.datamodelBilling import BillingModelEnum + + billingInterface = getBillingRootInterface() + settings = billingInterface.getSettings(mandateId) + + if not settings: + return # No billing configured for this mandate + + billingModel = settings.get("billingModel", "UNLIMITED") + if billingModel != BillingModelEnum.PREPAY_USER.value: + return # Only create user accounts for PREPAY_USER model + + defaultCredit = settings.get("defaultUserCredit", 10.0) + billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit) + logger.info(f"Created billing account for user {userId} in mandate {mandateId} with {defaultCredit} CHF") + + except Exception as e: + # Don't fail user mandate creation if billing account creation fails + logger.warning(f"Failed to create billing account for user {userId} (non-critical): {e}") def deleteUserMandate(self, userId: str, mandateId: str) -> bool: """ diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index 141fc118..b3df9bea 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -360,6 +360,98 @@ class BillingObjects: return created + def ensureAllUserAccountsExist(self) -> int: + """ + Efficiently ensure all users across all mandates have billing accounts. + Uses bulk queries to minimize database connections. + + Returns: + Number of accounts created + """ + from modules.interfaces.interfaceDbApp import getRootInterface as getAppRootInterface + + try: + appInterface = getAppRootInterface() + accountsCreated = 0 + + # Step 1: Get all billing settings in one query (only PREPAY_USER mandates need user accounts) + allSettings = self.db.getRecordset(BillingSettings) + prepayUserMandates = {} + for s in allSettings: + if s.get("billingModel") == BillingModelEnum.PREPAY_USER.value: + prepayUserMandates[s.get("mandateId")] = s.get("defaultUserCredit", 10.0) + + if not prepayUserMandates: + logger.debug("No PREPAY_USER mandates found, skipping account check") + return 0 + + # Step 2: Get all existing USER accounts in one query + allAccounts = self.db.getRecordset( + BillingAccount, + recordFilter={"accountType": AccountTypeEnum.USER.value} + ) + # Build set of existing (mandateId, userId) pairs + existingAccountKeys = set() + for acc in allAccounts: + key = (acc.get("mandateId"), acc.get("userId")) + existingAccountKeys.add(key) + + # Step 3: Get all user-mandate combinations in one query + allUserMandates = appInterface.db.getRecordset( + appInterface.db.getModel("UserMandate"), + recordFilter={"enabled": True} + ) + + # Step 4: Find missing accounts and create them + for um in allUserMandates: + mandateId = um.get("mandateId") + userId = um.get("userId") + + if not mandateId or not userId: + continue + + # Only process mandates with PREPAY_USER billing + if mandateId not in prepayUserMandates: + continue + + # Check if account already exists (in memory, no DB call) + key = (mandateId, userId) + if key in existingAccountKeys: + continue + + # Create missing account + defaultCredit = prepayUserMandates[mandateId] + account = BillingAccount( + mandateId=mandateId, + userId=userId, + accountType=AccountTypeEnum.USER, + balance=defaultCredit, + enabled=True + ) + created = self.createAccount(account) + + # Create initial credit transaction + if defaultCredit > 0: + self.createTransaction(BillingTransaction( + accountId=created["id"], + transactionType=TransactionTypeEnum.CREDIT, + amount=defaultCredit, + description="Initial credit for new user", + referenceType=ReferenceTypeEnum.SYSTEM + )) + + existingAccountKeys.add(key) # Track newly created + accountsCreated += 1 + + if accountsCreated > 0: + logger.info(f"Created {accountsCreated} missing billing accounts") + + return accountsCreated + + except Exception as e: + logger.error(f"Error ensuring user accounts exist: {e}") + return 0 + # ========================================================================= # BillingTransaction Operations # ========================================================================= @@ -502,11 +594,16 @@ class BillingObjects: # Get the relevant account if billingModel == BillingModelEnum.PREPAY_USER: account = self.getUserAccount(mandateId, userId) + # Auto-create user account if not exists (with default credit from settings) + if not account: + defaultCredit = settings.get("defaultUserCredit", 10.0) + logger.info(f"Auto-creating billing account for user {userId} in mandate {mandateId} with {defaultCredit} CHF initial credit") + account = self.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit) else: account = self.getMandateAccount(mandateId) if not account: - # No account = no balance = potentially blocked + # No account (only happens for mandate-level accounts) = potentially blocked if settings.get("blockOnZeroBalance", True): return BillingCheckResult( allowed=False, @@ -713,11 +810,18 @@ class BillingObjects: userMandates = appInterface.getUserMandates(userId) for um in userMandates: - mandateId = um.get("mandateId") + # Handle both Pydantic models and dicts + mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) + if not mandateId: + continue + mandate = appInterface.getMandate(mandateId) if not mandate: continue + # Get mandate name (handle both Pydantic and dict) + mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") + settings = self.getSettings(mandateId) if not settings: continue @@ -740,7 +844,7 @@ class BillingObjects: balances.append(BillingBalanceResponse( mandateId=mandateId, - mandateName=mandate.get("name", ""), + mandateName=mandateName, billingModel=billingModel, balance=balance, warningThreshold=warningThreshold, diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index e3698509..ffeea594 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -123,7 +123,7 @@ async def getBalance( """ try: billingService = getBillingService( - ctx.currentUser, + ctx.user, ctx.mandateId, featureCode="billing" ) @@ -148,7 +148,7 @@ async def getBalanceForMandate( """ try: billingService = getBillingService( - ctx.currentUser, + ctx.user, targetMandateId, featureCode="billing" ) @@ -158,7 +158,7 @@ async def getBalanceForMandate( # Get mandate name from app interface from modules.interfaces.interfaceDbApp import getInterface as getAppInterface - appInterface = getAppInterface(ctx.currentUser, mandateId=targetMandateId) + appInterface = getAppInterface(ctx.user, mandateId=targetMandateId) mandate = appInterface.getMandate(targetMandateId) mandateName = mandate.get("name", "") if mandate else "" @@ -190,7 +190,7 @@ async def getTransactions( """ try: billingService = getBillingService( - ctx.currentUser, + ctx.user, ctx.mandateId, featureCode="billing" ) @@ -240,7 +240,7 @@ async def getStatistics( if period == "day" and not month: raise HTTPException(status_code=400, detail="Month is required for 'day' period") - billingInterface = getBillingInterface(ctx.currentUser, ctx.mandateId) + billingInterface = getBillingInterface(ctx.user, ctx.mandateId) settings = billingInterface.getSettings(ctx.mandateId) if not settings: @@ -256,7 +256,7 @@ async def getStatistics( # Get the relevant account if billingModel == BillingModelEnum.PREPAY_USER: - account = billingInterface.getUserAccount(ctx.mandateId, ctx.currentUser.id) + account = billingInterface.getUserAccount(ctx.mandateId, ctx.user.id) else: account = billingInterface.getMandateAccount(ctx.mandateId) @@ -316,7 +316,7 @@ async def getAllowedProviders( """ try: billingService = getBillingService( - ctx.currentUser, + ctx.user, ctx.mandateId, featureCode="billing" ) @@ -344,7 +344,7 @@ async def getSettingsAdmin( Get billing settings for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) settings = billingInterface.getSettings(targetMandateId) if not settings: @@ -372,7 +372,7 @@ async def createOrUpdateSettings( Create or update billing settings for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) existingSettings = billingInterface.getSettings(targetMandateId) if existingSettings: @@ -421,7 +421,7 @@ async def addCredit( """ try: # Get settings to determine billing model - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) settings = billingInterface.getSettings(targetMandateId) if not settings: @@ -482,7 +482,7 @@ async def getAccounts( Get all billing accounts for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) # Get all accounts for this mandate via interface accounts = billingInterface.getAccountsByMandate(targetMandateId) @@ -507,6 +507,70 @@ async def getAccounts( raise HTTPException(status_code=500, detail=str(e)) +class MandateUserSummary(BaseModel): + """Summary of a user for billing admin purposes.""" + id: str + email: Optional[str] = None + firstName: Optional[str] = None + lastName: Optional[str] = None + displayName: Optional[str] = None + + +@router.get("/admin/users/{targetMandateId}", response_model=List[MandateUserSummary]) +@limiter.limit("30/minute") +async def getUsersForMandate( + request: Request, + targetMandateId: str = Path(..., description="Mandate ID"), + ctx: RequestContext = Depends(getRequestContext), + _admin = Depends(requireSysAdmin) +): + """ + Get all users belonging to a mandate (SysAdmin only). + Used by billing admin to select users for credit assignment. + """ + try: + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + + appInterface = getAppInterface(ctx.user, mandateId=targetMandateId) + userMandates = appInterface.getUserMandatesByMandate(targetMandateId) + + result = [] + for um in userMandates: + userId = um.get("userId") if isinstance(um, dict) else getattr(um, "userId", None) + if not userId: + continue + + user = appInterface.getUser(userId) + if not user: + continue + + # Handle both Pydantic models and dicts + if isinstance(user, dict): + firstName = user.get("firstName", "") + lastName = user.get("lastName", "") + email = user.get("email", "") + else: + firstName = getattr(user, "firstName", "") or "" + lastName = getattr(user, "lastName", "") or "" + email = getattr(user, "email", "") or "" + + displayName = f"{firstName} {lastName}".strip() or email or userId + + result.append(MandateUserSummary( + id=userId, + email=email, + firstName=firstName, + lastName=lastName, + displayName=displayName + )) + + return result + + except Exception as e: + logger.error(f"Error getting users for mandate {targetMandateId}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/admin/transactions/{targetMandateId}", response_model=List[TransactionResponse]) @limiter.limit("30/minute") async def getTransactionsAdmin( @@ -520,7 +584,7 @@ async def getTransactionsAdmin( Get all transactions for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=limit) result = [] diff --git a/modules/services/serviceAi/mainServiceAi.py b/modules/services/serviceAi/mainServiceAi.py index 81d83022..3d2f5cba 100644 --- a/modules/services/serviceAi/mainServiceAi.py +++ b/modules/services/serviceAi/mainServiceAi.py @@ -184,16 +184,17 @@ class AiService: ) logger.debug(f"Automation provider check passed: {effectiveProviders}") - # Check if preferred provider (from UI selection) is allowed - preferredProvider = getattr(self.services, 'preferredProvider', None) - if preferredProvider: - if preferredProvider not in rbacAllowedProviders: - logger.warning(f"Preferred provider {preferredProvider} not allowed for user {user.id}") - raise ProviderNotAllowedException( - provider=preferredProvider, - message=f"Der gewählte Provider '{preferredProvider}' ist für Ihre Rolle nicht freigegeben." - ) - logger.debug(f"Preferred provider {preferredProvider} is allowed") + # Check if preferred providers (from UI multiselect) are allowed + preferredProviders = getattr(self.services, 'preferredProviders', None) + if preferredProviders: + for provider in preferredProviders: + if provider not in rbacAllowedProviders: + logger.warning(f"Preferred provider {provider} not allowed for user {user.id}") + raise ProviderNotAllowedException( + provider=provider, + message=f"Der gewählte Provider '{provider}' ist für Ihre Rolle nicht freigegeben." + ) + logger.debug(f"All preferred providers are allowed: {preferredProviders}") logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed") diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py index 99a89df1..6a0a00e4 100644 --- a/modules/workflows/automation/mainWorkflow.py +++ b/modules/workflows/automation/mainWorkflow.py @@ -42,10 +42,14 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode try: services = getServices(currentUser, mandateId=mandateId) - # Store preferred provider in services context for billing/model selection - if hasattr(userInput, 'preferredProvider') and userInput.preferredProvider: - services.preferredProvider = userInput.preferredProvider - logger.debug(f"Using preferred provider: {userInput.preferredProvider}") + # Store preferred providers in services context for billing/model selection + # Support both preferredProviders (list) and legacy preferredProvider (string) + if hasattr(userInput, 'preferredProviders') and userInput.preferredProviders: + services.preferredProviders = userInput.preferredProviders + logger.debug(f"Using preferred providers: {userInput.preferredProviders}") + elif hasattr(userInput, 'preferredProvider') and userInput.preferredProvider: + services.preferredProviders = [userInput.preferredProvider] + logger.debug(f"Using preferred provider (legacy): {userInput.preferredProvider}") # Store feature instance ID in services context if featureInstanceId: @@ -59,10 +63,14 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode logger.error(f"Error starting chat: {str(e)}") raise -async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None) -> ChatWorkflow: +async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow: """Stops a running chat.""" try: services = getServices(currentUser, mandateId=mandateId) + # Store feature instance ID in services context for proper RBAC filtering + if featureInstanceId: + services.featureInstanceId = featureInstanceId + services.featureCode = 'chatplayground' workflowManager = WorkflowManager(services) return await workflowManager.workflowStop(workflowId) except Exception as e: diff --git a/modules/workflows/workflowManager.py b/modules/workflows/workflowManager.py index d0c35bf1..b15c66b7 100644 --- a/modules/workflows/workflowManager.py +++ b/modules/workflows/workflowManager.py @@ -97,6 +97,7 @@ class WorkflowManager: "totalTasks": 0, "totalActions": 0, "mandateId": self.services.mandateId, + "featureInstanceId": getattr(self.services, 'featureInstanceId', None), # Feature instance ID for isolation "messageIds": [], "workflowMode": workflowMode, "maxSteps": 10 , # Set maxSteps From bbea0ff1153813c6af8a42f4539a588b9c942483 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Sun, 8 Feb 2026 00:25:48 +0100 Subject: [PATCH 08/18] revised state machine for workflow backend and ui --- app.py | 26 +- modules/datamodels/datamodelChat.py | 4 +- modules/datamodels/datamodelUam.py | 2 + .../chatbot/interfaceFeatureChatbot.py | 13 +- modules/interfaces/interfaceDbApp.py | 27 +- modules/interfaces/interfaceDbBilling.py | 410 +++++++++++++++++- modules/interfaces/interfaceDbChat.py | 5 +- modules/routes/routeBilling.py | 243 ++++++++++- modules/routes/routeSecurityGoogle.py | 4 + modules/routes/routeSecurityMsft.py | 3 + .../serviceBilling/mainServiceBilling.py | 4 +- .../mainServiceGeneration.py | 8 + modules/system/mainSystem.py | 14 +- .../composeAndDraftEmailWithContext.py | 22 +- .../methodOutlook/actions/sendDraftEmail.py | 15 +- .../methodOutlook/helpers/connection.py | 8 + .../methods/methodOutlook/methodOutlook.py | 4 +- .../workflows/processing/modes/modeDynamic.py | 42 +- modules/workflows/workflowManager.py | 43 +- 19 files changed, 795 insertions(+), 102 deletions(-) diff --git a/app.py b/app.py index 0acdcfce..df7f9306 100644 --- a/app.py +++ b/app.py @@ -313,32 +313,18 @@ async def lifespan(app: FastAPI): from modules.shared.auditLogger import registerAuditLogCleanupScheduler registerAuditLogCleanupScheduler() - # Ensure billing settings and accounts exist + # Ensure billing settings and accounts exist for all mandates try: from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface - from modules.datamodels.datamodelBilling import BillingSettings, BillingModelEnum billingInterface = getBillingRootInterface() - # Ensure root mandate has billing settings - rootMandate = rootInterface.getRootMandate() - if rootMandate: - rootMandateId = rootMandate.get("id") if isinstance(rootMandate, dict) else getattr(rootMandate, "id", None) - if rootMandateId: - existingSettings = billingInterface.getSettings(rootMandateId) - if not existingSettings: - settings = BillingSettings( - mandateId=rootMandateId, - billingModel=BillingModelEnum.PREPAY_USER, - defaultUserCredit=10.0, - warningThresholdPercent=10.0, - blockOnZeroBalance=True, - notifyOnWarning=True - ) - billingInterface.createSettings(settings) - logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit") + # Step 1: Ensure all mandates have billing settings (creates defaults if missing) + settingsCreated = billingInterface.ensureAllMandateSettingsExist() + if settingsCreated > 0: + logger.info(f"Billing startup: Created {settingsCreated} missing mandate billing settings") - # Efficient bulk check: Ensure all users have billing accounts (3 queries total) + # Step 2: Ensure all users have billing accounts (for PREPAY_USER mandates) accountsCreated = billingInterface.ensureAllUserAccountsExist() if accountsCreated > 0: logger.info(f"Billing startup: Created {accountsCreated} missing user accounts") diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index b1e73ae0..e2d631e8 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -12,6 +12,8 @@ import uuid class ChatStat(BaseModel): """Statistics for chat operations. User-owned, no mandate context.""" + model_config = {"populate_by_name": True, "extra": "allow"} # Allow DB system fields + id: str = Field( default_factory=lambda: str(uuid.uuid4()), description="Primary key" ) @@ -41,7 +43,7 @@ registerModelLabels( "errorCount": {"en": "Error Count", "fr": "Nombre d'erreurs"}, "process": {"en": "Process", "fr": "Processus"}, "engine": {"en": "Engine", "fr": "Moteur"}, - "priceCHF": {"en": "Price USD", "fr": "Prix USD"}, + "priceCHF": {"en": "Price CHF", "fr": "Prix CHF"}, }, ) diff --git a/modules/datamodels/datamodelUam.py b/modules/datamodels/datamodelUam.py index b0e6b468..155047a2 100644 --- a/modules/datamodels/datamodelUam.py +++ b/modules/datamodels/datamodelUam.py @@ -114,6 +114,7 @@ class UserConnection(BaseModel): {"value": "none", "label": {"en": "None", "fr": "Aucun"}}, ]}) tokenExpiresAt: Optional[float] = Field(None, description="When the current token expires (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False}) + grantedScopes: Optional[List[str]] = Field(None, description="OAuth scopes granted for this connection", json_schema_extra={"frontend_type": "list", "frontend_readonly": True, "frontend_required": False}) @computed_field @computed_field @@ -146,6 +147,7 @@ registerModelLabels( "expiresAt": {"en": "Expires At", "de": "Läuft ab am", "fr": "Expire le"}, "tokenStatus": {"en": "Connection Status", "de": "Verbindungsstatus", "fr": "Statut de connexion"}, "tokenExpiresAt": {"en": "Expires At", "de": "Läuft ab am", "fr": "Expire le"}, + "grantedScopes": {"en": "Granted Scopes", "de": "Gewährte Berechtigungen", "fr": "Autorisations accordées"}, "connectionReference": {"en": "Connection Reference", "de": "Verbindungsreferenz", "fr": "Référence de connexion"}, "displayLabel": {"en": "Display Label", "de": "Anzeigebezeichnung", "fr": "Libellé d'affichage"}, }, diff --git a/modules/features/chatbot/interfaceFeatureChatbot.py b/modules/features/chatbot/interfaceFeatureChatbot.py index c1e5977e..68474898 100644 --- a/modules/features/chatbot/interfaceFeatureChatbot.py +++ b/modules/features/chatbot/interfaceFeatureChatbot.py @@ -1116,7 +1116,7 @@ class ChatObjects: # Emit message event for streaming (if event manager is available) try: - from modules.features.chatbot.eventManager import get_event_manager + from modules.features.chatbot.eventManager import get_event_manager # type: ignore event_manager = get_event_manager() message_timestamp = parseTimestamp(chat_message.publishedAt, default=getUtcTimestamp()) # Emit message event in exact chatData format: {type, createdAt, item} @@ -1514,7 +1514,7 @@ class ChatObjects: # Only emit events for chatbot workflows, not for automation or dynamic workflows if workflow.workflowMode == WorkflowModeEnum.WORKFLOW_CHATBOT: try: - from modules.features.chatbot.eventManager import get_event_manager + from modules.features.chatbot.eventManager import get_event_manager # type: ignore event_manager = get_event_manager() log_timestamp = parseTimestamp(createdLog.get("timestamp"), default=getUtcTimestamp()) # Emit log event in exact chatData format: {type, createdAt, item} @@ -1563,8 +1563,8 @@ class ChatObjects: if not stats: return [] - # Return all stats records sorted by creation time - stats.sort(key=lambda x: x.get("created_at", "")) + # Return all stats records sorted by _createdAt (system field from DB) + stats.sort(key=lambda x: x.get("_createdAt", 0)) # Ensure mandateId and featureInstanceId are set for each stat return [ChatStat(**{**stat, "mandateId": stat.get("mandateId") or self.mandateId or "", "featureInstanceId": stat.get("featureInstanceId") or self.featureInstanceId or ""}) for stat in stats] @@ -1680,11 +1680,12 @@ class ChatObjects: "item": chatLog }) - # Get stats list + # Get stats - ChatStat model now supports _createdAt via extra="allow" stats = self.getStats(workflowId) for stat in stats: # Apply timestamp filtering in Python - stat_timestamp = stat.createdAt if hasattr(stat, 'createdAt') else getUtcTimestamp() + # Use _createdAt (system field from DB, preserved via model_config extra="allow") + stat_timestamp = getattr(stat, '_createdAt', None) or getUtcTimestamp() if afterTimestamp is not None and stat_timestamp <= afterTimestamp: continue diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py index 1c082d33..1d8359a5 100644 --- a/modules/interfaces/interfaceDbApp.py +++ b/modules/interfaces/interfaceDbApp.py @@ -55,6 +55,9 @@ _gatewayInterfaces = {} # Root interface instance _rootAppObjects = None +# Bootstrap completion flag - ensures bootstrap runs only ONCE per application lifecycle +_bootstrapCompleted = False + # Password-Hashing pwdContext = CryptContext(schemes=["argon2"], deprecated="auto") @@ -200,8 +203,28 @@ class AppObjects: return simpleFields, objectFields def _initRecords(self): - """Initialize standard records if they don't exist.""" - initBootstrap(self.db) + """Initialize standard records if they don't exist. + + Uses a global flag to ensure bootstrap only runs ONCE per application lifecycle. + The flag is set BEFORE calling bootstrap to prevent recursive calls during bootstrap. + """ + global _bootstrapCompleted + + if _bootstrapCompleted: + return + + # Set flag BEFORE bootstrap to prevent recursive calls during bootstrap + _bootstrapCompleted = True + logger.info("Starting bootstrap (will only run once)") + + try: + initBootstrap(self.db) + logger.info("Bootstrap completed successfully") + except Exception as e: + # Reset flag on failure so bootstrap can be retried + _bootstrapCompleted = False + logger.error(f"Bootstrap failed: {e}") + raise def checkRbacPermission( diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index b3df9bea..bbb26d20 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -15,7 +15,8 @@ import uuid from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.shared.configuration import APP_CONFIG from modules.shared.timeUtils import getUtcTimestamp -from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelUam import User, Mandate +from modules.datamodels.datamodelMembership import UserMandate from modules.datamodels.datamodelBilling import ( BillingAccount, BillingTransaction, @@ -360,6 +361,60 @@ class BillingObjects: return created + def ensureAllMandateSettingsExist(self) -> int: + """ + Efficiently ensure all mandates have billing settings. + Creates default settings (PREPAY_USER) for mandates without settings. + Uses bulk queries to minimize database connections. + + Returns: + Number of settings created + """ + try: + settingsCreated = 0 + + # Step 1: Get all existing billing settings in one query (from billing DB) + allSettings = self.db.getRecordset(BillingSettings) + existingMandateIds = set(s.get("mandateId") for s in allSettings if s.get("mandateId")) + + # Step 2: Get all mandates from APP database (separate connection) + appDb = DatabaseConnector( + dbDatabase=APP_CONFIG.get('DB_DATABASE', 'poweron_app'), + dbHost=APP_CONFIG.get('DB_HOST', 'localhost'), + dbPort=int(APP_CONFIG.get('DB_PORT', '5432')), + dbUser=APP_CONFIG.get('DB_USER'), + dbPassword=APP_CONFIG.get('DB_PASSWORD_SECRET') + ) + allMandates = appDb.getRecordset(Mandate, recordFilter={"enabled": True}) + + # Step 3: Create settings for mandates that don't have them + for mandate in allMandates: + mandateId = mandate.get("id") + if not mandateId or mandateId in existingMandateIds: + continue + + # Create default billing settings + settings = BillingSettings( + mandateId=mandateId, + billingModel=BillingModelEnum.PREPAY_USER, + defaultUserCredit=10.0, + warningThresholdPercent=10.0, + blockOnZeroBalance=True, + notifyOnWarning=True + ) + self.createSettings(settings) + existingMandateIds.add(mandateId) # Track newly created + settingsCreated += 1 + + if settingsCreated > 0: + logger.info(f"Created {settingsCreated} missing billing settings for mandates") + + return settingsCreated + + except Exception as e: + logger.error(f"Error ensuring mandate settings exist: {e}") + return 0 + def ensureAllUserAccountsExist(self) -> int: """ Efficiently ensure all users across all mandates have billing accounts. @@ -368,10 +423,7 @@ class BillingObjects: Returns: Number of accounts created """ - from modules.interfaces.interfaceDbApp import getRootInterface as getAppRootInterface - try: - appInterface = getAppRootInterface() accountsCreated = 0 # Step 1: Get all billing settings in one query (only PREPAY_USER mandates need user accounts) @@ -385,7 +437,7 @@ class BillingObjects: logger.debug("No PREPAY_USER mandates found, skipping account check") return 0 - # Step 2: Get all existing USER accounts in one query + # Step 2: Get all existing USER accounts in one query (from billing DB) allAccounts = self.db.getRecordset( BillingAccount, recordFilter={"accountType": AccountTypeEnum.USER.value} @@ -396,9 +448,16 @@ class BillingObjects: key = (acc.get("mandateId"), acc.get("userId")) existingAccountKeys.add(key) - # Step 3: Get all user-mandate combinations in one query - allUserMandates = appInterface.db.getRecordset( - appInterface.db.getModel("UserMandate"), + # Step 3: Get all user-mandate combinations from APP database (separate connection) + appDb = DatabaseConnector( + dbDatabase=APP_CONFIG.get('DB_DATABASE', 'poweron_app'), + dbHost=APP_CONFIG.get('DB_HOST', 'localhost'), + dbPort=int(APP_CONFIG.get('DB_PORT', '5432')), + dbUser=APP_CONFIG.get('DB_USER'), + dbPassword=APP_CONFIG.get('DB_PASSWORD_SECRET') + ) + allUserMandates = appDb.getRecordset( + UserMandate, recordFilter={"enabled": True} ) @@ -855,3 +914,338 @@ class BillingObjects: logger.error(f"Error getting balances for user: {e}") return balances + + def getTransactionsForUser(self, userId: str, limit: int = 100) -> List[Dict[str, Any]]: + """ + Get all transactions for a user across all mandates they belong to. + + Args: + userId: User ID + limit: Maximum number of results + + Returns: + List of transaction dicts + """ + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + + allTransactions = [] + + try: + appInterface = getAppInterface(self.currentUser) + userMandates = appInterface.getUserMandates(userId) + + for um in userMandates: + # Handle both Pydantic models and dicts + mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) + if not mandateId: + continue + + # Only include mandates with billing settings + settings = self.getSettings(mandateId) + if not settings: + continue + + # Get transactions for this mandate + transactions = self.getTransactionsByMandate(mandateId, limit=limit) + + # Add mandate context to each transaction + mandate = appInterface.getMandate(mandateId) + mandateName = "" + if mandate: + mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") + + for t in transactions: + t["mandateId"] = mandateId + t["mandateName"] = mandateName + allTransactions.append(t) + + except Exception as e: + logger.error(f"Error getting transactions for user: {e}") + + # Sort by creation date descending and limit + allTransactions.sort(key=lambda x: x.get("_createdAt", ""), reverse=True) + return allTransactions[:limit] + + # ========================================================================= + # Mandate View Operations (Admin-Level) + # ========================================================================= + + def getMandateBalances(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]: + """ + Get mandate-level balances. + + Args: + mandateIds: Optional list of mandate IDs to filter. If None, returns all. + + Returns: + List of mandate balance dicts + """ + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + + balances = [] + + try: + appInterface = getAppInterface(self.currentUser) + + # Get settings for filtering + if mandateIds: + allSettings = [self.getSettings(mId) for mId in mandateIds] + allSettings = [s for s in allSettings if s] + else: + allSettings = self.db.getRecordset(BillingSettings) + + for settings in allSettings: + mandateId = settings.get("mandateId") + if not mandateId: + continue + + billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) + + # Get mandate info + mandate = appInterface.getMandate(mandateId) + mandateName = "" + if mandate: + mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") + + # For PREPAY_MANDATE, get the mandate account balance + # For PREPAY_USER, aggregate all user balances + if billingModel == BillingModelEnum.PREPAY_MANDATE: + account = self.getMandateAccount(mandateId) + totalBalance = account.get("balance", 0.0) if account else 0.0 + userCount = 0 + elif billingModel == BillingModelEnum.PREPAY_USER: + # Get all user accounts for this mandate + userAccounts = self.db.getRecordset( + BillingAccount, + recordFilter={"mandateId": mandateId, "accountType": AccountTypeEnum.USER.value} + ) + totalBalance = sum(acc.get("balance", 0.0) for acc in userAccounts) + userCount = len(userAccounts) + else: + totalBalance = 0.0 + userCount = 0 + + balances.append({ + "mandateId": mandateId, + "mandateName": mandateName, + "billingModel": billingModel.value, + "totalBalance": totalBalance, + "userCount": userCount, + "defaultUserCredit": settings.get("defaultUserCredit", 0.0), + "warningThresholdPercent": settings.get("warningThresholdPercent", 10.0), + "blockOnZeroBalance": settings.get("blockOnZeroBalance", True) + }) + + except Exception as e: + logger.error(f"Error getting mandate balances: {e}") + + return balances + + def getMandateTransactions(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]: + """ + Get all transactions for specified mandates. + + Args: + mandateIds: Optional list of mandate IDs to filter. If None, returns all. + limit: Maximum number of results + + Returns: + List of transaction dicts with mandate context + """ + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + + allTransactions = [] + + try: + appInterface = getAppInterface(self.currentUser) + + # Determine which mandates to query + if mandateIds: + targetMandateIds = mandateIds + else: + allSettings = self.db.getRecordset(BillingSettings) + targetMandateIds = [s.get("mandateId") for s in allSettings if s.get("mandateId")] + + for mandateId in targetMandateIds: + transactions = self.getTransactionsByMandate(mandateId, limit=limit) + + # Get mandate name + mandate = appInterface.getMandate(mandateId) + mandateName = "" + if mandate: + mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") + + for t in transactions: + t["mandateId"] = mandateId + t["mandateName"] = mandateName + allTransactions.append(t) + + except Exception as e: + logger.error(f"Error getting mandate transactions: {e}") + + # Sort by creation date descending and limit + allTransactions.sort(key=lambda x: x.get("_createdAt", ""), reverse=True) + return allTransactions[:limit] + + # ========================================================================= + # User View Operations (User-Level with RBAC) + # ========================================================================= + + def getUserBalancesForMandates(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]: + """ + Get all user-level balances for specified mandates. + + Args: + mandateIds: Optional list of mandate IDs to filter. If None, returns all. + + Returns: + List of user balance dicts with mandate and user context + """ + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + + balances = [] + + try: + appInterface = getAppInterface(self.currentUser) + + # Get all user accounts + accountFilter = {"accountType": AccountTypeEnum.USER.value} + allAccounts = self.db.getRecordset(BillingAccount, recordFilter=accountFilter) + + # Filter by mandate if specified + if mandateIds: + allAccounts = [acc for acc in allAccounts if acc.get("mandateId") in mandateIds] + + # Get all relevant settings in one query + settingsMap = {} + allSettings = self.db.getRecordset(BillingSettings) + for s in allSettings: + settingsMap[s.get("mandateId")] = s + + # Get user info efficiently + userIds = list(set(acc.get("userId") for acc in allAccounts if acc.get("userId"))) + userMap = {} + for userId in userIds: + user = appInterface.getUser(userId) + if user: + displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) + email = getattr(user, 'email', None) or (user.get("email") if isinstance(user, dict) else None) + userMap[userId] = displayName or email or userId + + # Get mandate info efficiently + mandateMap = {} + mandateIdList = list(set(acc.get("mandateId") for acc in allAccounts if acc.get("mandateId"))) + for mandateId in mandateIdList: + mandate = appInterface.getMandate(mandateId) + if mandate: + mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") + mandateMap[mandateId] = mandateName + + for account in allAccounts: + mandateId = account.get("mandateId") + userId = account.get("userId") + + if not mandateId or not userId: + continue + + settings = settingsMap.get(mandateId) + if not settings: + continue + + balance = account.get("balance", 0.0) + warningThreshold = account.get("warningThreshold", 0.0) + + balances.append({ + "accountId": account.get("id"), + "mandateId": mandateId, + "mandateName": mandateMap.get(mandateId, ""), + "userId": userId, + "userName": userMap.get(userId, userId), + "balance": balance, + "warningThreshold": warningThreshold, + "isWarning": balance <= warningThreshold, + "enabled": account.get("enabled", True) + }) + + except Exception as e: + logger.error(f"Error getting user balances for mandates: {e}") + + return balances + + def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]: + """ + Get all user-level transactions for specified mandates. + + Args: + mandateIds: Optional list of mandate IDs to filter. If None, returns all. + limit: Maximum number of results + + Returns: + List of transaction dicts with mandate and user context + """ + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + + allTransactions = [] + + try: + appInterface = getAppInterface(self.currentUser) + + # Get all user accounts + accountFilter = {"accountType": AccountTypeEnum.USER.value} + allAccounts = self.db.getRecordset(BillingAccount, recordFilter=accountFilter) + + # Filter by mandate if specified + if mandateIds: + allAccounts = [acc for acc in allAccounts if acc.get("mandateId") in mandateIds] + + # Build account to user/mandate mapping + accountMap = {} + for acc in allAccounts: + accountMap[acc.get("id")] = { + "mandateId": acc.get("mandateId"), + "userId": acc.get("userId") + } + + # Get user info efficiently + userIds = list(set(acc.get("userId") for acc in allAccounts if acc.get("userId"))) + userMap = {} + for userId in userIds: + user = appInterface.getUser(userId) + if user: + displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) + email = getattr(user, 'email', None) or (user.get("email") if isinstance(user, dict) else None) + userMap[userId] = displayName or email or userId + + # Get mandate info efficiently + mandateMap = {} + mandateIdList = list(set(acc.get("mandateId") for acc in allAccounts if acc.get("mandateId"))) + for mandateId in mandateIdList: + mandate = appInterface.getMandate(mandateId) + if mandate: + mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") + mandateMap[mandateId] = mandateName + + # Get transactions for all accounts + for account in allAccounts: + accountId = account.get("id") + if not accountId: + continue + + transactions = self.getTransactions(accountId, limit=limit) + accountInfo = accountMap.get(accountId, {}) + mandateId = accountInfo.get("mandateId") + userId = accountInfo.get("userId") + + for t in transactions: + t["mandateId"] = mandateId + t["mandateName"] = mandateMap.get(mandateId, "") + t["userId"] = userId + t["userName"] = userMap.get(userId, userId) + allTransactions.append(t) + + except Exception as e: + logger.error(f"Error getting user transactions for mandates: {e}") + + # Sort by creation date descending and limit + allTransactions.sort(key=lambda x: x.get("_createdAt", ""), reverse=True) + return allTransactions[:limit] diff --git a/modules/interfaces/interfaceDbChat.py b/modules/interfaces/interfaceDbChat.py index 0aec7fe0..0a3971a8 100644 --- a/modules/interfaces/interfaceDbChat.py +++ b/modules/interfaces/interfaceDbChat.py @@ -1634,11 +1634,12 @@ class ChatObjects: "item": chatLog }) - # Get stats list + # Get stats - ChatStat model supports _createdAt via model_config extra="allow" stats = self.getStats(workflowId) for stat in stats: # Apply timestamp filtering in Python - stat_timestamp = stat.createdAt if hasattr(stat, 'createdAt') else getUtcTimestamp() + # Use _createdAt (system field from DB, preserved via model_config extra="allow") + stat_timestamp = getattr(stat, '_createdAt', None) or getUtcTimestamp() if afterTimestamp is not None and stat_timestamp <= afterTimestamp: continue diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index ffeea594..6cc0a03b 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -74,6 +74,8 @@ class TransactionResponse(BaseModel): featureCode: Optional[str] aicoreProvider: Optional[str] createdAt: Optional[datetime] + mandateId: Optional[str] = None + mandateName: Optional[str] = None class AccountSummary(BaseModel): @@ -97,6 +99,53 @@ class UsageReportResponse(BaseModel): costByFeature: Dict[str, float] +# ============================================================================= +# Response Models for Mandate/User Views +# ============================================================================= + +class MandateBalanceResponse(BaseModel): + """Mandate-level balance summary.""" + mandateId: str + mandateName: str + billingModel: str + totalBalance: float + userCount: int + defaultUserCredit: float + warningThresholdPercent: float + blockOnZeroBalance: bool + + +class UserBalanceResponse(BaseModel): + """User-level balance summary.""" + accountId: str + mandateId: str + mandateName: str + userId: str + userName: str + balance: float + warningThreshold: float + isWarning: bool + enabled: bool + + +class UserTransactionResponse(BaseModel): + """User-level transaction with user context.""" + id: str + accountId: str + transactionType: TransactionTypeEnum + amount: float + description: str + referenceType: Optional[ReferenceTypeEnum] + workflowId: Optional[str] + featureCode: Optional[str] + aicoreProvider: Optional[str] + createdAt: Optional[datetime] + mandateId: Optional[str] = None + mandateName: Optional[str] = None + userId: Optional[str] = None + userName: Optional[str] = None + + # ============================================================================= # Router Setup # ============================================================================= @@ -186,7 +235,7 @@ async def getTransactions( ctx: RequestContext = Depends(getRequestContext) ): """ - Get transaction history for the current mandate. + Get transaction history across all mandates the user belongs to. """ try: billingService = getBillingService( @@ -195,7 +244,8 @@ async def getTransactions( featureCode="billing" ) - transactions = billingService.getTransactionHistory(limit=limit) + # Fetch enough transactions for pagination + transactions = billingService.getTransactionHistory(limit=offset + limit) # Convert to response model result = [] @@ -210,7 +260,9 @@ async def getTransactions( workflowId=t.get("workflowId"), featureCode=t.get("featureCode"), aicoreProvider=t.get("aicoreProvider"), - createdAt=t.get("_createdAt") + createdAt=t.get("_createdAt"), + mandateId=t.get("mandateId"), + mandateName=t.get("mandateName") )) return result @@ -607,3 +659,188 @@ async def getTransactionsAdmin( except Exception as e: logger.error(f"Error getting billing transactions for mandate {targetMandateId}: {e}") raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# Mandate View Endpoints (for Admins) +# ============================================================================= + +@router.get("/view/mandates/balances", response_model=List[MandateBalanceResponse]) +@limiter.limit("30/minute") +async def getMandateViewBalances( + request: Request, + ctx: RequestContext = Depends(getRequestContext), + _admin = Depends(requireSysAdmin) +): + """ + Get mandate-level balances (SysAdmin only). + Shows aggregated balances per mandate. + """ + try: + billingInterface = getBillingInterface(ctx.user, ctx.mandateId) + balances = billingInterface.getMandateBalances() + + return [MandateBalanceResponse(**b) for b in balances] + + except Exception as e: + logger.error(f"Error getting mandate view balances: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/view/mandates/transactions", response_model=List[TransactionResponse]) +@limiter.limit("30/minute") +async def getMandateViewTransactions( + request: Request, + limit: int = Query(default=100, ge=1, le=1000), + ctx: RequestContext = Depends(getRequestContext), + _admin = Depends(requireSysAdmin) +): + """ + Get all transactions across mandates (SysAdmin only). + """ + try: + billingInterface = getBillingInterface(ctx.user, ctx.mandateId) + transactions = billingInterface.getMandateTransactions(limit=limit) + + result = [] + for t in transactions: + result.append(TransactionResponse( + id=t.get("id"), + accountId=t.get("accountId"), + transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")), + amount=t.get("amount", 0.0), + description=t.get("description", ""), + referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, + workflowId=t.get("workflowId"), + featureCode=t.get("featureCode"), + aicoreProvider=t.get("aicoreProvider"), + createdAt=t.get("_createdAt"), + mandateId=t.get("mandateId"), + mandateName=t.get("mandateName") + )) + + return result + + except Exception as e: + logger.error(f"Error getting mandate view transactions: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# User View Endpoints (RBAC-based) +# ============================================================================= + +@router.get("/view/users/balances", response_model=List[UserBalanceResponse]) +@limiter.limit("30/minute") +async def getUserViewBalances( + request: Request, + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get user-level balances. + - SysAdmin: sees all user balances across all mandates + - MandateAdmin: sees user balances for mandates they manage + - Regular user: sees only their own balances + """ + try: + billingInterface = getBillingInterface(ctx.user, ctx.mandateId) + + # Determine which mandates the user has access to + if ctx.user.isSysAdmin: + # SysAdmin sees all + mandateIds = None + else: + # Get mandates where user is admin or has billing access + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + appInterface = getAppInterface(ctx.user) + userMandates = appInterface.getUserMandates(ctx.user.id) + + # Filter to only mandates where user has admin role + # For simplicity, we'll check if user is admin in any mandate + mandateIds = [] + for um in userMandates: + mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) + if mandateId: + mandateIds.append(mandateId) + + if not mandateIds: + return [] + + allBalances = billingInterface.getUserBalancesForMandates(mandateIds) + + # Non-admin users only see their own balances + if not ctx.user.isSysAdmin: + allBalances = [b for b in allBalances if b.get("userId") == ctx.user.id] + + return [UserBalanceResponse(**b) for b in allBalances] + + except Exception as e: + logger.error(f"Error getting user view balances: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/view/users/transactions", response_model=List[UserTransactionResponse]) +@limiter.limit("30/minute") +async def getUserViewTransactions( + request: Request, + limit: int = Query(default=100, ge=1, le=1000), + ctx: RequestContext = Depends(getRequestContext) +): + """ + Get user-level transactions. + - SysAdmin: sees all user transactions across all mandates + - MandateAdmin: sees user transactions for mandates they manage + - Regular user: sees only their own transactions + """ + try: + billingInterface = getBillingInterface(ctx.user, ctx.mandateId) + + # Determine which mandates the user has access to + if ctx.user.isSysAdmin: + # SysAdmin sees all + mandateIds = None + else: + # Get mandates where user has access + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + appInterface = getAppInterface(ctx.user) + userMandates = appInterface.getUserMandates(ctx.user.id) + + mandateIds = [] + for um in userMandates: + mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) + if mandateId: + mandateIds.append(mandateId) + + if not mandateIds: + return [] + + allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=limit) + + # Non-admin users only see their own transactions + if not ctx.user.isSysAdmin: + allTransactions = [t for t in allTransactions if t.get("userId") == ctx.user.id] + + result = [] + for t in allTransactions: + result.append(UserTransactionResponse( + id=t.get("id"), + accountId=t.get("accountId"), + transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")), + amount=t.get("amount", 0.0), + description=t.get("description", ""), + referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, + workflowId=t.get("workflowId"), + featureCode=t.get("featureCode"), + aicoreProvider=t.get("aicoreProvider"), + createdAt=t.get("_createdAt"), + mandateId=t.get("mandateId"), + mandateName=t.get("mandateName"), + userId=t.get("userId"), + userName=t.get("userName") + )) + + return result + + except Exception as e: + logger.error(f"Error getting user view transactions: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/modules/routes/routeSecurityGoogle.py b/modules/routes/routeSecurityGoogle.py index d8ef3bef..4ee634ed 100644 --- a/modules/routes/routeSecurityGoogle.py +++ b/modules/routes/routeSecurityGoogle.py @@ -487,6 +487,10 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo connection.externalId = user_info.get("id") connection.externalUsername = user_info.get("email") connection.externalEmail = user_info.get("email") + # Store actually granted scopes for this connection + granted_scopes_list = granted_scopes.split(" ") if granted_scopes else SCOPES + connection.grantedScopes = granted_scopes_list + logger.info(f"Storing granted scopes for connection {connection_id}: {granted_scopes_list}") # Update connection record directly rootInterface.db.recordModify(UserConnection, connection_id, connection.model_dump()) diff --git a/modules/routes/routeSecurityMsft.py b/modules/routes/routeSecurityMsft.py index 68bf6fe8..0abb2f56 100644 --- a/modules/routes/routeSecurityMsft.py +++ b/modules/routes/routeSecurityMsft.py @@ -498,6 +498,9 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo connection.externalId = user_info.get("id") connection.externalUsername = user_info.get("userPrincipalName") connection.externalEmail = user_info.get("mail") + # Store granted scopes for this connection + connection.grantedScopes = SCOPES + logger.info(f"Storing granted scopes for connection {connection_id}: {SCOPES}") # Update connection record directly rootInterface.db.recordModify(UserConnection, connection_id, connection.model_dump()) diff --git a/modules/services/serviceBilling/mainServiceBilling.py b/modules/services/serviceBilling/mainServiceBilling.py index c7a08a1c..472e0b58 100644 --- a/modules/services/serviceBilling/mainServiceBilling.py +++ b/modules/services/serviceBilling/mainServiceBilling.py @@ -367,7 +367,7 @@ class BillingService: def getTransactionHistory(self, limit: int = 100) -> List[Dict[str, Any]]: """ - Get transaction history for the current mandate. + Get transaction history for the user across all mandates. Args: limit: Maximum number of transactions @@ -375,7 +375,7 @@ class BillingService: Returns: List of transactions """ - return self._billingInterface.getTransactionsByMandate(self.mandateId, limit=limit) + return self._billingInterface.getTransactionsForUser(self.currentUser.id, limit=limit) # ============================================================================ diff --git a/modules/services/serviceGeneration/mainServiceGeneration.py b/modules/services/serviceGeneration/mainServiceGeneration.py index a49b78c7..4720c9a0 100644 --- a/modules/services/serviceGeneration/mainServiceGeneration.py +++ b/modules/services/serviceGeneration/mainServiceGeneration.py @@ -74,6 +74,14 @@ class GenerationService: document_data_dict = document_data.dict() elif isinstance(document_data, dict): document_data_dict = document_data + elif isinstance(document_data, str): + # JSON-String: parsen und als dict speichern (z.B. von outlook.composeAndDraftEmailWithContext) + import json + try: + document_data_dict = json.loads(document_data) + except json.JSONDecodeError: + # Kein valides JSON - als plain text speichern + document_data_dict = {"data": document_data} else: document_data_dict = {"data": str(document_data)} diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py index e80efbe6..c48add01 100644 --- a/modules/system/mainSystem.py +++ b/modules/system/mainSystem.py @@ -96,21 +96,13 @@ NAVIGATION_SECTIONS = [ "title": {"en": "BILLING", "de": "BILLING", "fr": "FACTURATION"}, "order": 35, "items": [ - { - "id": "billing-dashboard", - "objectKey": "ui.billing.dashboard", - "label": {"en": "Balance", "de": "Guthaben", "fr": "Solde"}, - "icon": "FaWallet", - "path": "/billing", - "order": 10, - }, { "id": "billing-transactions", "objectKey": "ui.billing.transactions", - "label": {"en": "Transactions", "de": "Transaktionen", "fr": "Transactions"}, - "icon": "FaListAlt", + "label": {"en": "Billing", "de": "Billing", "fr": "Facturation"}, + "icon": "FaWallet", "path": "/billing/transactions", - "order": 20, + "order": 10, }, ], }, diff --git a/modules/workflows/methods/methodOutlook/actions/composeAndDraftEmailWithContext.py b/modules/workflows/methods/methodOutlook/actions/composeAndDraftEmailWithContext.py index 59604896..e8bc94b3 100644 --- a/modules/workflows/methods/methodOutlook/actions/composeAndDraftEmailWithContext.py +++ b/modules/workflows/methods/methodOutlook/actions/composeAndDraftEmailWithContext.py @@ -13,16 +13,17 @@ logger = logging.getLogger(__name__) async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> ActionResult: try: connectionReference = parameters.get("connectionReference") - to = parameters.get("to") + to = parameters.get("to") or [] # Optional for drafts - can save draft without recipients context = parameters.get("context") - documentList = parameters.get("documentList", []) - cc = parameters.get("cc", []) - bcc = parameters.get("bcc", []) - emailStyle = parameters.get("emailStyle", "business") - maxLength = parameters.get("maxLength", 1000) + documentList = parameters.get("documentList") or [] + cc = parameters.get("cc") or [] + bcc = parameters.get("bcc") or [] + emailStyle = parameters.get("emailStyle") or "business" + maxLength = parameters.get("maxLength") or 1000 - if not connectionReference or not to or not context: - return ActionResult.isFailure(error="connectionReference, to, and context are required") + # Only connectionReference and context are required - to is optional for drafts + if not connectionReference or not context: + return ActionResult.isFailure(error="connectionReference and context are required") # Convert single values to lists for all recipient parameters if isinstance(to, str): @@ -82,12 +83,15 @@ async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> A # Escape only the user-controlled context to prevent prompt injection escaped_context = context.replace('"', '\\"').replace('\n', '\\n').replace('\r', '\\r') + # Build recipients text for prompt + recipients_text = f"Recipients: {to}" if to else "Recipients: (not specified - this is a draft)" + ai_prompt = f"""Compose an email based on this context: ------- {escaped_context} ------- -Recipients: {to} +{recipients_text} Style: {emailStyle} Max length: {maxLength} characters {doc_list_text} diff --git a/modules/workflows/methods/methodOutlook/actions/sendDraftEmail.py b/modules/workflows/methods/methodOutlook/actions/sendDraftEmail.py index 9b7fb011..15c35f44 100644 --- a/modules/workflows/methods/methodOutlook/actions/sendDraftEmail.py +++ b/modules/workflows/methods/methodOutlook/actions/sendDraftEmail.py @@ -90,15 +90,20 @@ async def sendDraftEmail(self, parameters: Dict[str, Any]) -> ActionResult: else: jsonContent = str(fileData) - # Parse JSON - handle both direct JSON and JSON wrapped in documentData + # Parse JSON - handle ActionDocument format with validationMetadata wrapper try: draftEmailData = json.loads(jsonContent) - # If the JSON contains a 'documentData' field, extract it + # ActionDocument format: { "validationMetadata": {...}, "documentData": {...} } + # Extract documentData which contains the actual draft email data if isinstance(draftEmailData, dict) and 'documentData' in draftEmailData: - documentDataStr = draftEmailData['documentData'] - if isinstance(documentDataStr, str): - draftEmailData = json.loads(documentDataStr) + documentDataContent = draftEmailData['documentData'] + # documentData should be a dict (parsed from JSON by processSingleDocument) + if isinstance(documentDataContent, dict): + draftEmailData = documentDataContent + elif isinstance(documentDataContent, str): + # Legacy/fallback: parse if still a string + draftEmailData = json.loads(documentDataContent) # Validate draft email structure if not isinstance(draftEmailData, dict): diff --git a/modules/workflows/methods/methodOutlook/helpers/connection.py b/modules/workflows/methods/methodOutlook/helpers/connection.py index 8f3daded..12621fd3 100644 --- a/modules/workflows/methods/methodOutlook/helpers/connection.py +++ b/modules/workflows/methods/methodOutlook/helpers/connection.py @@ -84,6 +84,14 @@ class ConnectionHelper: elif response.status_code == 403: logger.error("Permission denied - connection lacks necessary mail permissions") logger.error("Required scopes: Mail.ReadWrite, Mail.Send, Mail.ReadWrite.Shared") + logger.error("Solution: User must reconnect and grant mail permissions") + return False + elif response.status_code == 404: + # 404 on /me/mailFolders typically means the token lacks mail scopes + # This happens when the connection was created without mail permissions + logger.error("Mail API not accessible (404) - token likely lacks mail scopes") + logger.error("This usually means the connection was created without Mail.ReadWrite permission") + logger.error("Solution: User must delete the connection and reconnect, granting mail permissions") return False else: logger.warning(f"Permission check returned status {response.status_code}") diff --git a/modules/workflows/methods/methodOutlook/methodOutlook.py b/modules/workflows/methods/methodOutlook/methodOutlook.py index 4a978b7a..8d80cef5 100644 --- a/modules/workflows/methods/methodOutlook/methodOutlook.py +++ b/modules/workflows/methods/methodOutlook/methodOutlook.py @@ -150,8 +150,8 @@ class MethodOutlook(MethodBase): name="to", type="List[str]", frontendType=FrontendType.MULTISELECT, - required=True, - description="Recipient email addresses" + required=False, + description="Recipient email addresses (optional for drafts)" ), "context": WorkflowActionParameter( name="context", diff --git a/modules/workflows/processing/modes/modeDynamic.py b/modules/workflows/processing/modes/modeDynamic.py index 1510e512..e59a9253 100644 --- a/modules/workflows/processing/modes/modeDynamic.py +++ b/modules/workflows/processing/modes/modeDynamic.py @@ -204,29 +204,29 @@ class DynamicMode(BaseMode): if quality_score is None: quality_score = 0.0 logger.info(f"Content validation: {validationResult.get('overallSuccess', False)} (quality: {quality_score:.2f})") + + # Record validation result for adaptive learning + actionValue = selection.get('action', 'unknown') + actionContext = { + 'actionName': actionValue, + 'workflowId': context.workflowId + } + + self.adaptiveLearningEngine.recordValidationResult( + validationResult, + actionContext, + context.workflowId, + step + ) + + # Learn from feedback - use taskIntent (task-level), not workflowIntent + feedback = self._collectFeedback(result, validationResult, self.taskIntent) + self.learningEngine.learnFromFeedback(feedback, context, self.taskIntent) + + # Update progress - use taskIntent (task-level), not workflowIntent + self.progressTracker.updateOperation(result, validationResult, self.taskIntent) else: logger.info("Content validation skipped: no documents to validate") - - # NEW: Record validation result for adaptive learning - actionValue = selection.get('action', 'unknown') - actionContext = { - 'actionName': actionValue, - 'workflowId': context.workflowId - } - - self.adaptiveLearningEngine.recordValidationResult( - validationResult, - actionContext, - context.workflowId, - step - ) - - # NEW: Learn from feedback - use taskIntent (task-level), not workflowIntent - feedback = self._collectFeedback(result, validationResult, self.taskIntent) - self.learningEngine.learnFromFeedback(feedback, context, self.taskIntent) - - # NEW: Update progress - use taskIntent (task-level), not workflowIntent - self.progressTracker.updateOperation(result, validationResult, self.taskIntent) decision = await self._refineDecide(context, observation) diff --git a/modules/workflows/workflowManager.py b/modules/workflows/workflowManager.py index b15c66b7..030a966f 100644 --- a/modules/workflows/workflowManager.py +++ b/modules/workflows/workflowManager.py @@ -430,11 +430,33 @@ The following is the user's original input message. Analyze intent, normalize th workflow = self.services.workflow checkWorkflowStopped(self.services) + # Send "first" message to mark round start (consistent with full workflow) + normalizedRequest = getattr(self.services, 'currentUserPromptNormalized', None) or userInput.prompt + roundNum = workflow.currentRound + contextLabel = f"round{roundNum}_usercontext" + + firstMessageData = { + "workflowId": workflow.id, + "role": "user", + "message": normalizedRequest, + "status": "first", + "sequenceNr": len(workflow.messages) + 1, + "publishedAt": self.services.utils.timestampGetUtc(), + "documentsLabel": contextLabel, + "documents": [], + "roundNumber": roundNum, + "taskNumber": 0, + "actionNumber": 0, + "taskProgress": "pending", + "actionProgress": "pending" + } + self.services.chat.storeMessageWithDocuments(workflow, firstMessageData, []) + # Get user language if available userLanguage = getattr(self.services, 'currentUserLanguage', None) # Execute fast path - use normalizedRequest if available, otherwise use raw prompt - normalizedPrompt = getattr(self.services, 'currentUserPromptNormalized', None) or userInput.prompt + normalizedPrompt = normalizedRequest result = await self.workflowProcessor.fastPathExecute( prompt=normalizedPrompt, documents=documents, @@ -491,14 +513,6 @@ The following is the user's original input message. Analyze intent, normalize th } chatDocuments.append(chatDoc) - # Mark workflow as completed BEFORE storing message (so UI polling stops) - workflow.status = "completed" - workflow.lastActivity = self.services.utils.timestampGetUtc() - self.services.chat.updateWorkflow(workflow.id, { - "status": "completed", - "lastActivity": workflow.lastActivity - }) - # Create ChatMessage with fast path response (in user's language) messageData = { "workflowId": workflow.id, @@ -518,9 +532,18 @@ The following is the user's original input message. Analyze intent, normalize th "actionProgress": "success" } - # Store message with documents + # Store message with documents BEFORE marking workflow as completed + # This ensures UI polling sees the "last" message before status changes self.services.chat.storeMessageWithDocuments(workflow, messageData, chatDocuments) + # Mark workflow as completed AFTER storing message + workflow.status = "completed" + workflow.lastActivity = self.services.utils.timestampGetUtc() + self.services.chat.updateWorkflow(workflow.id, { + "status": "completed", + "lastActivity": workflow.lastActivity + }) + logger.info(f"Fast path completed successfully, response length: {len(responseText)} chars") except Exception as e: From e4d41965f3358a4378ab9d0b700bf0aa6dc396b5 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Sun, 8 Feb 2026 01:44:43 +0100 Subject: [PATCH 09/18] fixed stats and billing sync --- modules/datamodels/datamodelAi.py | 3 + modules/datamodels/datamodelChat.py | 3 +- modules/interfaces/interfaceAiObjects.py | 11 ++ modules/interfaces/interfaceDbChat.py | 28 ++++- modules/services/serviceAi/mainServiceAi.py | 115 +++++++++++++++--- .../services/serviceAi/subAiCallLooping.py | 12 +- .../serviceGeneration/paths/imagePath.py | 6 +- modules/workflows/automation/mainWorkflow.py | 12 +- .../workflows/processing/workflowProcessor.py | 8 +- 9 files changed, 151 insertions(+), 47 deletions(-) diff --git a/modules/datamodels/datamodelAi.py b/modules/datamodels/datamodelAi.py index c9d81bfa..4233b7d7 100644 --- a/modules/datamodels/datamodelAi.py +++ b/modules/datamodels/datamodelAi.py @@ -144,6 +144,9 @@ class AiCallOptions(BaseModel): temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)") maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch") + # Provider filtering (from UI multiselect or automation config) + allowedProviders: Optional[List[str]] = Field(default=None, description="List of allowed AI providers to use (empty = all RBAC-permitted)") + class AiCallRequest(BaseModel): """Centralized AI call request payload for interface use.""" diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index e2d631e8..fbad3d57 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -403,8 +403,7 @@ class UserInputRequest(BaseModel): listFileId: List[str] = Field(default_factory=list, description="List of file IDs") userLanguage: str = Field(default="en", description="User's preferred language") workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue") - preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai') - deprecated, use preferredProviders") - preferredProviders: Optional[List[str]] = Field(None, description="List of preferred AI providers (multiselect)") + allowedProviders: Optional[List[str]] = Field(None, description="List of allowed AI providers (multiselect)") registerModelLabels( diff --git a/modules/interfaces/interfaceAiObjects.py b/modules/interfaces/interfaceAiObjects.py index 2e6e36f5..b2d91ed0 100644 --- a/modules/interfaces/interfaceAiObjects.py +++ b/modules/interfaces/interfaceAiObjects.py @@ -89,6 +89,17 @@ class AiObjects: # Get failover models for this operation type availableModels = modelRegistry.getAvailableModels() + + # Filter by allowedProviders if specified (from workflow config) + allowedProviders = getattr(options, 'allowedProviders', None) if options else None + if allowedProviders: + filteredModels = [m for m in availableModels if m.connectorType in allowedProviders] + if filteredModels: + logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})") + availableModels = filteredModels + else: + logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models") + failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels) if not failoverModelList: diff --git a/modules/interfaces/interfaceDbChat.py b/modules/interfaces/interfaceDbChat.py index 0a3971a8..9ee20fc0 100644 --- a/modules/interfaces/interfaceDbChat.py +++ b/modules/interfaces/interfaceDbChat.py @@ -1532,8 +1532,19 @@ class ChatObjects: return [] # Return all stats records sorted by creation time - stats.sort(key=lambda x: x.get("created_at", "")) - return [ChatStat(**stat) for stat in stats] + # DB uses _createdAt (camelCase system field) + stats.sort(key=lambda x: x.get("_createdAt", 0)) + + # Convert to ChatStat objects, preserving _createdAt via extra="allow" + result = [] + for stat in stats: + chat_stat = ChatStat(**stat) + # Explicitly preserve _createdAt from raw DB record + if "_createdAt" in stat: + setattr(chat_stat, '_createdAt', stat["_createdAt"]) + result.append(chat_stat) + + return result def createStat(self, statData: Dict[str, Any]) -> ChatStat: @@ -1549,9 +1560,16 @@ class ChatObjects: # Validate the stat data against ChatStat model stat = ChatStat(**statData) + logger.debug(f"Creating stat for workflow {statData.get('workflowId')}: " + f"process={statData.get('process')}, " + f"priceCHF={statData.get('priceCHF', 0):.4f}, " + f"processingTime={statData.get('processingTime', 0):.2f}s") + # Create the stat record in the database created = self.db.recordCreate(ChatStat, stat) + logger.info(f"Created stat {created.get('id')} for workflow {statData.get('workflowId')}") + # Return the created ChatStat return ChatStat(**created) except Exception as e: @@ -1643,10 +1661,14 @@ class ChatObjects: if afterTimestamp is not None and stat_timestamp <= afterTimestamp: continue + # Convert to dict and include _createdAt for frontend + stat_dict = stat.model_dump() if hasattr(stat, 'model_dump') else stat.dict() + stat_dict['_createdAt'] = stat_timestamp + items.append({ "type": "stat", "createdAt": stat_timestamp, - "item": stat + "item": stat_dict }) # Sort all items by createdAt timestamp for chronological order diff --git a/modules/services/serviceAi/mainServiceAi.py b/modules/services/serviceAi/mainServiceAi.py index 3d2f5cba..5fdf32a5 100644 --- a/modules/services/serviceAi/mainServiceAi.py +++ b/modules/services/serviceAi/mainServiceAi.py @@ -94,15 +94,30 @@ class AiService: Includes billing checks: - Balance check before AI call - Provider permission check (via RBAC) + + Also stores workflow stats after each successful AI call. """ - # Billing check before AI call + # Billing check before AI call (validates RBAC permissions) await self._checkBillingBeforeAiCall() + # Calculate effective allowedProviders: RBAC ∩ Workflow + # RBAC is master - only RBAC-permitted providers can ever be used + effectiveProviders = self._calculateEffectiveProviders() + if effectiveProviders and request.options: + request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders}) + logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}") + if hasattr(request, 'contentParts') and request.contentParts: - return await self.extractionService.processContentPartsWithAi( + response = await self.extractionService.processContentPartsWithAi( request, self.aiObjects, progressCallback ) - return await self.aiObjects.callWithTextContext(request) + else: + response = await self.aiObjects.callWithTextContext(request) + + # Store workflow stats after each AI call + self._storeAiCallStats(response, request) + + return response async def _checkBillingBeforeAiCall(self) -> None: """ @@ -206,6 +221,87 @@ class AiService: # Log but don't block on billing check errors logger.warning(f"Billing check failed with error (non-blocking): {e}") + def _calculateEffectiveProviders(self) -> Optional[List[str]]: + """ + Calculate effective allowed providers: RBAC ∩ Workflow. + + RBAC is master - only RBAC-permitted providers can ever be used. + If workflow specifies allowedProviders, intersect with RBAC. + If no workflow providers, use all RBAC-permitted providers. + + Returns: + List of effective allowed providers, or None if no filtering needed + """ + try: + user = getattr(self.services, 'user', None) + mandateId = getattr(self.services, 'mandateId', None) + + if not user or not mandateId: + return None + + # Get RBAC-permitted providers (master list) + # Note: getBillingService is imported at module level from mainServiceBilling + featureInstanceId = getattr(self.services, 'featureInstanceId', None) + featureCode = getattr(self.services, 'featureCode', None) + billingService = getBillingService(user, mandateId, featureInstanceId, featureCode) + rbacProviders = billingService.getallowedProviders() + + if not rbacProviders: + logger.warning("No RBAC-permitted providers found") + return None + + # Get workflow-specified providers (optional filter) + workflowProviders = getattr(self.services, 'allowedProviders', None) + + if workflowProviders: + # Intersect: only providers that are both RBAC-permitted AND workflow-allowed + effectiveProviders = [p for p in workflowProviders if p in rbacProviders] + logger.debug(f"Provider filter: RBAC={rbacProviders}, Workflow={workflowProviders}, Effective={effectiveProviders}") + else: + # No workflow filter - use all RBAC-permitted providers + effectiveProviders = rbacProviders + logger.debug(f"Provider filter: RBAC={rbacProviders} (no workflow filter)") + + return effectiveProviders if effectiveProviders else None + + except Exception as e: + logger.warning(f"Error calculating effective providers: {e}") + return None + + def _storeAiCallStats(self, response, request: AiCallRequest) -> None: + """Store workflow stats after an AI call. + + This method stores the AI call statistics (cost, processing time, bytes) + to the workflow stats collection for tracking and billing purposes. + + Args: + response: AiCallResponse with cost/timing data + request: Original AiCallRequest for context + """ + try: + # Skip if no workflow context + workflow = getattr(self.services, 'workflow', None) + if not workflow or not hasattr(workflow, 'id') or not workflow.id: + logger.debug("No workflow context - skipping stats storage") + return + + # Skip if response is an error + if not response or getattr(response, 'errorCount', 0) > 0: + logger.debug("Error response - skipping stats storage") + return + + # Determine process name from operation type + opType = getattr(request.options, 'operationType', 'unknown') if request.options else 'unknown' + process = f"ai.call.{opType}" + + # Store the stat + self.services.chat.storeWorkflowStat(workflow, response, process) + logger.debug(f"Stored AI call stat: {process}, cost={getattr(response, 'priceCHF', 0):.4f} CHF") + + except Exception as e: + # Log but don't fail - stats storage is not critical + logger.debug(f"Could not store AI call stat: {str(e)}") + async def ensureAiObjectsInitialized(self): """Ensure aiObjects is initialized and submodules are ready.""" if self.aiObjects is None: @@ -428,7 +524,7 @@ Respond with ONLY a JSON object in this exact format: # Debug: persist prompt/response for analysis with context-specific naming debugPrefix = debugType if debugType else "plan" self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt") - response = await self.aiObjects.callWithTextContext(request) + response = await self.callAi(request) # Use callAi to ensure stats are stored result = response.content or "" self.services.utils.writeDebugFile(result, f"{debugPrefix}_response") return result @@ -485,16 +581,7 @@ Respond with ONLY a JSON object in this exact format: operationType=opType.value ) - # Try to store workflow stats, but don't fail if workflow is None (e.g., in chatbot context) - try: - self.services.chat.storeWorkflowStat( - self.services.workflow, - response, - f"ai.{opType.name.lower()}" - ) - except Exception as e: - # Log but don't fail - workflow might be None in some contexts (e.g., chatbot) - logger.debug(f"Could not store workflow stat (workflow may be None): {str(e)}") + # Note: Stats are now stored centrally in callAi() - no need to duplicate here self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed") self.services.chat.progressLogFinish(aiOperationId, True) diff --git a/modules/services/serviceAi/subAiCallLooping.py b/modules/services/serviceAi/subAiCallLooping.py index 5f3fb79f..2e4edc3e 100644 --- a/modules/services/serviceAi/subAiCallLooping.py +++ b/modules/services/serviceAi/subAiCallLooping.py @@ -269,17 +269,7 @@ class AiCallLooper: # Document generation - save all iteration responses self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}") - # Emit stats for this iteration (only if workflow exists and has id) - if self.services.workflow and hasattr(self.services.workflow, 'id') and self.services.workflow.id: - try: - self.services.chat.storeWorkflowStat( - self.services.workflow, - response, - f"ai.call.{debugPrefix}.iteration_{iteration}" - ) - except Exception as statError: - # Don't break the main loop if stat storage fails - logger.warning(f"Failed to store workflow stat: {str(statError)}") + # Note: Stats are now stored centrally in callAi() - no need to duplicate here # Check for error response using generic error detection (errorCount > 0 or modelName == "error") if hasattr(response, 'errorCount') and response.errorCount > 0: diff --git a/modules/services/serviceGeneration/paths/imagePath.py b/modules/services/serviceGeneration/paths/imagePath.py index 1247494f..c61bc997 100644 --- a/modules/services/serviceGeneration/paths/imagePath.py +++ b/modules/services/serviceGeneration/paths/imagePath.py @@ -101,11 +101,7 @@ class ImageGenerationPath: operationType=OperationTypeEnum.IMAGE_GENERATE.value ) - self.services.chat.storeWorkflowStat( - self.services.workflow, - response, - "ai.generate.image" - ) + # Note: Stats are now stored centrally in callAi() - no need to duplicate here self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated") self.services.chat.progressLogFinish(imageOperationId, True) diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py index 6a0a00e4..e63f7932 100644 --- a/modules/workflows/automation/mainWorkflow.py +++ b/modules/workflows/automation/mainWorkflow.py @@ -42,14 +42,10 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode try: services = getServices(currentUser, mandateId=mandateId) - # Store preferred providers in services context for billing/model selection - # Support both preferredProviders (list) and legacy preferredProvider (string) - if hasattr(userInput, 'preferredProviders') and userInput.preferredProviders: - services.preferredProviders = userInput.preferredProviders - logger.debug(f"Using preferred providers: {userInput.preferredProviders}") - elif hasattr(userInput, 'preferredProvider') and userInput.preferredProvider: - services.preferredProviders = [userInput.preferredProvider] - logger.debug(f"Using preferred provider (legacy): {userInput.preferredProvider}") + # Store allowedProviders in services context for model selection + if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders: + services.allowedProviders = userInput.allowedProviders + logger.info(f"AI provider filter active: {userInput.allowedProviders}") # Store feature instance ID in services context if featureInstanceId: diff --git a/modules/workflows/processing/workflowProcessor.py b/modules/workflows/processing/workflowProcessor.py index 38763f51..a78a2270 100644 --- a/modules/workflows/processing/workflowProcessor.py +++ b/modules/workflows/processing/workflowProcessor.py @@ -14,7 +14,8 @@ from modules.workflows.processing.modes.modeDynamic import DynamicMode from modules.workflows.processing.modes.modeAutomation import AutomationMode from modules.workflows.processing.shared.stateTools import checkWorkflowStopped from modules.datamodels.datamodelAi import OperationTypeEnum, PriorityEnum, ProcessingModeEnum, AiCallOptions, AiCallRequest -from modules.shared.jsonUtils import extractJsonString, repairBrokenJson +from modules.shared.jsonUtils import extractJsonString, repairBrokenJson, parseJsonWithModel +from modules.datamodels.datamodelWorkflow import UnderstandingResult if TYPE_CHECKING: from modules.datamodels.datamodelWorkflow import TaskResult @@ -477,8 +478,7 @@ class WorkflowProcessor: maxProcessingTime=15 # Fast path should complete in 15s ) - # Call AI directly (no document generation - just plain text response) - # Use callWithTextContext() for text-only calls + # Call AI via callAi() to ensure stats are stored aiRequest = AiCallRequest( prompt=fastPathPrompt, context="", @@ -486,7 +486,7 @@ class WorkflowProcessor: contentParts=None # Fast path doesn't process documents ) - aiCallResponse = await self.services.ai.aiObjects.callWithTextContext(aiRequest) + aiCallResponse = await self.services.ai.callAi(aiRequest) # Extract response content (AiCallResponse.content is a string) responseText = aiCallResponse.content if aiCallResponse.content else "" From 34deb5f23dd306e5875303d87bd72a2d498cd40d Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Sun, 8 Feb 2026 13:15:19 +0100 Subject: [PATCH 10/18] fixed billing transactions mapping and added reporting --- modules/datamodels/datamodelAi.py | 1 + modules/interfaces/interfaceAiObjects.py | 1 + modules/interfaces/interfaceDbBilling.py | 7 +- modules/routes/routeBilling.py | 290 ++++++++++++++++-- .../services/serviceChat/mainServiceChat.py | 5 +- .../mainServiceExtraction.py | 46 +-- 6 files changed, 273 insertions(+), 77 deletions(-) diff --git a/modules/datamodels/datamodelAi.py b/modules/datamodels/datamodelAi.py index 4233b7d7..44eac445 100644 --- a/modules/datamodels/datamodelAi.py +++ b/modules/datamodels/datamodelAi.py @@ -162,6 +162,7 @@ class AiCallResponse(BaseModel): content: str = Field(description="AI response content") modelName: str = Field(description="Selected model name") + provider: str = Field(default="unknown", description="AI provider / connectorType (anthropic, openai, perplexity, etc.)") priceCHF: float = Field(default=0.0, description="Calculated price in USD") processingTime: float = Field(default=0.0, description="Duration in seconds") bytesSent: int = Field(default=0, description="Input data size in bytes") diff --git a/modules/interfaces/interfaceAiObjects.py b/modules/interfaces/interfaceAiObjects.py index b2d91ed0..3976350d 100644 --- a/modules/interfaces/interfaceAiObjects.py +++ b/modules/interfaces/interfaceAiObjects.py @@ -229,6 +229,7 @@ class AiObjects: return AiCallResponse( content=content, modelName=model.name, + provider=model.connectorType, priceCHF=priceCHF, processingTime=processingTime, bytesSent=inputBytes, diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index bbb26d20..ae8b13ec 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -1174,7 +1174,7 @@ class BillingObjects: def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]: """ - Get all user-level transactions for specified mandates. + Get all transactions for specified mandates (both USER and MANDATE accounts). Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. @@ -1190,9 +1190,8 @@ class BillingObjects: try: appInterface = getAppInterface(self.currentUser) - # Get all user accounts - accountFilter = {"accountType": AccountTypeEnum.USER.value} - allAccounts = self.db.getRecordset(BillingAccount, recordFilter=accountFilter) + # Get ALL accounts (both USER and MANDATE types) to cover all billing models + allAccounts = self.db.getRecordset(BillingAccount) # Filter by mandate if specified if mandateIds: diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index 6cc0a03b..bd47c791 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -22,6 +22,8 @@ from modules.auth import limiter, requireSysAdmin, getRequestContext, RequestCon # Import billing components from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService +from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict +from modules.routes.routeDataUsers import _applyFiltersAndSort from modules.datamodels.datamodelBilling import ( BillingAccount, BillingTransaction, @@ -779,22 +781,212 @@ async def getUserViewBalances( raise HTTPException(status_code=500, detail=str(e)) -@router.get("/view/users/transactions", response_model=List[UserTransactionResponse]) +class ViewStatisticsResponse(BaseModel): + """Aggregated statistics across all user's mandates.""" + totalCost: float = 0.0 + transactionCount: int = 0 + costByProvider: Dict[str, float] = {} + costByFeature: Dict[str, float] = {} + costByMandate: Dict[str, float] = {} + timeSeries: List[Dict[str, Any]] = [] + + +@router.get("/view/statistics") +@limiter.limit("30/minute") +async def getUserViewStatistics( + request: Request, + period: str = Query(default="month", description="Period: 'day' or 'month'"), + year: int = Query(default=None, description="Year"), + month: Optional[int] = Query(None, description="Month (1-12, required for period='day')"), + ctx: RequestContext = Depends(getRequestContext) +) -> ViewStatisticsResponse: + """ + Get aggregated usage statistics across all user's mandates. + - period='month': returns monthly time series for the given year + - period='day': returns daily time series for the given month/year + """ + try: + from datetime import timedelta + + if year is None: + year = datetime.now().year + + if period == "day" and not month: + month = datetime.now().month + + billingInterface = getBillingInterface(ctx.user, ctx.mandateId) + + # Get all mandates the user has access to + if ctx.user.isSysAdmin: + mandateIds = None + else: + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + appInterface = getAppInterface(ctx.user) + userMandates = appInterface.getUserMandates(ctx.user.id) + mandateIds = [] + for um in userMandates: + mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) + if mandateId: + mandateIds.append(mandateId) + if not mandateIds: + logger.warning("No mandate IDs found for user") + return ViewStatisticsResponse() + + # Get all transactions + allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000) + logger.info(f"View statistics: {len(allTransactions)} total transactions fetched for period={period}, year={year}, month={month}") + + # Calculate date range + if period == "day": + startDate = date(year, month, 1) + if month == 12: + endDate = date(year + 1, 1, 1) + else: + endDate = date(year, month + 1, 1) + else: + startDate = date(year, 1, 1) + endDate = date(year + 1, 1, 1) + + # Filter by date range and only DEBIT transactions + debits = [] + skippedNoDate = 0 + skippedDateRange = 0 + skippedNotDebit = 0 + + for t in allTransactions: + createdAt = t.get("_createdAt") + if not createdAt: + skippedNoDate += 1 + continue + + # Parse date from various formats (DB stores as DOUBLE PRECISION / Unix timestamp) + txDate = None + if isinstance(createdAt, (int, float)): + txDate = datetime.fromtimestamp(createdAt).date() + elif isinstance(createdAt, datetime): + txDate = createdAt.date() + elif isinstance(createdAt, date) and not isinstance(createdAt, datetime): + txDate = createdAt + elif isinstance(createdAt, str): + try: + # Try as float string first (Unix timestamp) + txDate = datetime.fromtimestamp(float(createdAt)).date() + except (ValueError, TypeError): + try: + txDate = datetime.fromisoformat(createdAt.replace("Z", "+00:00")).date() + except (ValueError, TypeError): + skippedNoDate += 1 + continue + else: + skippedNoDate += 1 + continue + + if txDate < startDate or txDate >= endDate: + skippedDateRange += 1 + continue + + # Compare transactionType - handle both string and enum + txType = t.get("transactionType") + txTypeStr = str(txType) if txType is not None else "" + if txTypeStr != "DEBIT" and txTypeStr != "TransactionTypeEnum.DEBIT": + # Also check .value for enum objects + txTypeValue = getattr(txType, 'value', txTypeStr) + if txTypeValue != "DEBIT": + skippedNotDebit += 1 + continue + + t["_txDate"] = txDate + debits.append(t) + + logger.info(f"View statistics: {len(debits)} DEBIT transactions after filter. " + f"Skipped: noDate={skippedNoDate}, dateRange={skippedDateRange}, notDebit={skippedNotDebit}") + + # Aggregate totals + totalCost = sum(t.get("amount", 0) for t in debits) + + costByProvider: Dict[str, float] = {} + costByFeature: Dict[str, float] = {} + costByMandate: Dict[str, float] = {} + + for t in debits: + provider = t.get("aicoreProvider") or "unknown" + costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0) + + mandate = t.get("mandateName") or t.get("mandateId") or "unknown" + featureCode = t.get("featureCode") or "unknown" + featureKey = f"{mandate} / {featureCode}" + costByFeature[featureKey] = costByFeature.get(featureKey, 0) + t.get("amount", 0) + + mandate = t.get("mandateName") or t.get("mandateId") or "unknown" + costByMandate[mandate] = costByMandate.get(mandate, 0) + t.get("amount", 0) + + # Build time series (raw data only, no display logic) + timeSeries = [] + if period == "day": + numDays = (endDate - startDate).days + for day in range(numDays): + d = startDate + timedelta(days=day) + dayCost = sum(t.get("amount", 0) for t in debits if t["_txDate"] == d) + dayCount = sum(1 for t in debits if t["_txDate"] == d) + if dayCost > 0 or dayCount > 0: + timeSeries.append({ + "date": d.isoformat(), + "cost": round(dayCost, 4), + "count": dayCount + }) + else: + for m in range(1, 13): + mStart = date(year, m, 1) + mEnd = date(year, m + 1, 1) if m < 12 else date(year + 1, 1, 1) + monthCost = sum(t.get("amount", 0) for t in debits if mStart <= t["_txDate"] < mEnd) + monthCount = sum(1 for t in debits if mStart <= t["_txDate"] < mEnd) + timeSeries.append({ + "date": f"{year}-{m:02d}", + "cost": round(monthCost, 4), + "count": monthCount + }) + + return ViewStatisticsResponse( + totalCost=round(totalCost, 4), + transactionCount=len(debits), + costByProvider=costByProvider, + costByFeature=costByFeature, + costByMandate=costByMandate, + timeSeries=timeSeries + ) + + except Exception as e: + logger.error(f"Error getting view statistics: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/view/users/transactions", response_model=PaginatedResponse[UserTransactionResponse]) @limiter.limit("30/minute") async def getUserViewTransactions( request: Request, - limit: int = Query(default=100, ge=1, le=1000), + pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), ctx: RequestContext = Depends(getRequestContext) -): +) -> PaginatedResponse[UserTransactionResponse]: """ - Get user-level transactions. + Get user-level transactions with pagination support. - SysAdmin: sees all user transactions across all mandates - MandateAdmin: sees user transactions for mandates they manage - Regular user: sees only their own transactions + + Query Parameters: + - pagination: JSON-encoded PaginationParams object, or None for no pagination """ try: billingInterface = getBillingInterface(ctx.user, ctx.mandateId) + # Parse pagination params + paginationParams = None + if pagination: + import json + paginationDict = json.loads(pagination) + paginationDict = normalize_pagination_dict(paginationDict) + paginationParams = PaginationParams(**paginationDict) + # Determine which mandates the user has access to if ctx.user.isSysAdmin: # SysAdmin sees all @@ -812,34 +1004,78 @@ async def getUserViewTransactions( mandateIds.append(mandateId) if not mandateIds: - return [] + return PaginatedResponse(items=[], pagination=None) - allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=limit) + allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000) - # Non-admin users only see their own transactions - if not ctx.user.isSysAdmin: - allTransactions = [t for t in allTransactions if t.get("userId") == ctx.user.id] + logger.debug(f"Found {len(allTransactions)} transactions for mandates {mandateIds}") - result = [] + # Convert to response objects as dicts for filtering/sorting + transactionDicts = [] for t in allTransactions: - result.append(UserTransactionResponse( - id=t.get("id"), - accountId=t.get("accountId"), - transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")), - amount=t.get("amount", 0.0), - description=t.get("description", ""), - referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, - workflowId=t.get("workflowId"), - featureCode=t.get("featureCode"), - aicoreProvider=t.get("aicoreProvider"), - createdAt=t.get("_createdAt"), - mandateId=t.get("mandateId"), - mandateName=t.get("mandateName"), - userId=t.get("userId"), - userName=t.get("userName") - )) + transactionDicts.append({ + "id": t.get("id"), + "accountId": t.get("accountId"), + "transactionType": t.get("transactionType", "DEBIT"), + "amount": t.get("amount", 0.0), + "description": t.get("description", ""), + "referenceType": t.get("referenceType"), + "workflowId": t.get("workflowId"), + "featureCode": t.get("featureCode"), + "aicoreProvider": t.get("aicoreProvider"), + "createdAt": t.get("_createdAt"), + "mandateId": t.get("mandateId"), + "mandateName": t.get("mandateName"), + "userId": t.get("userId"), + "userName": t.get("userName"), + }) - return result + # Apply filters and sorting + filteredDicts = _applyFiltersAndSort(transactionDicts, paginationParams) + + # Convert to response models + def _toResponse(d): + return UserTransactionResponse( + id=d.get("id"), + accountId=d.get("accountId"), + transactionType=TransactionTypeEnum(d.get("transactionType", "DEBIT")), + amount=d.get("amount", 0.0), + description=d.get("description", ""), + referenceType=ReferenceTypeEnum(d["referenceType"]) if d.get("referenceType") else None, + workflowId=d.get("workflowId"), + featureCode=d.get("featureCode"), + aicoreProvider=d.get("aicoreProvider"), + createdAt=d.get("createdAt"), + mandateId=d.get("mandateId"), + mandateName=d.get("mandateName"), + userId=d.get("userId"), + userName=d.get("userName") + ) + + if paginationParams: + import math + totalItems = len(filteredDicts) + totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0 + startIdx = (paginationParams.page - 1) * paginationParams.pageSize + endIdx = startIdx + paginationParams.pageSize + paginatedDicts = filteredDicts[startIdx:endIdx] + + return PaginatedResponse( + items=[_toResponse(d) for d in paginatedDicts], + pagination=PaginationMetadata( + currentPage=paginationParams.page, + pageSize=paginationParams.pageSize, + totalItems=totalItems, + totalPages=totalPages, + sort=paginationParams.sort, + filters=paginationParams.filters + ) + ) + else: + return PaginatedResponse( + items=[_toResponse(d) for d in filteredDicts], + pagination=None + ) except Exception as e: logger.error(f"Error getting user view transactions: {e}") diff --git a/modules/services/serviceChat/mainServiceChat.py b/modules/services/serviceChat/mainServiceChat.py index 37b232b8..055e34cd 100644 --- a/modules/services/serviceChat/mainServiceChat.py +++ b/modules/services/serviceChat/mainServiceChat.py @@ -729,9 +729,8 @@ class ChatService: if not priceCHF or priceCHF <= 0: return - # Extract provider from model name (e.g., "anthropic.claude-3-sonnet" -> "anthropic") - modelName = getattr(aiResponse, 'modelName', '') or '' - aicoreProvider = modelName.split('.')[0] if '.' in modelName else 'unknown' + # Get provider from AiCallResponse (set from model.connectorType) + aicoreProvider = getattr(aiResponse, 'provider', None) or 'unknown' # Get feature context if available featureInstanceId = getattr(self.services, 'featureInstanceId', None) diff --git a/modules/services/serviceExtraction/mainServiceExtraction.py b/modules/services/serviceExtraction/mainServiceExtraction.py index 9ee9e739..e0bed993 100644 --- a/modules/services/serviceExtraction/mainServiceExtraction.py +++ b/modules/services/serviceExtraction/mainServiceExtraction.py @@ -229,6 +229,7 @@ class ExtractionService: aiResponse = AiCallResponse( content="", # No content for extraction stats needed modelName=model.name, + provider=model.connectorType, priceCHF=priceCHF, processingTime=processingTime, bytesSent=bytesSent, @@ -1311,6 +1312,7 @@ class ExtractionService: return AiCallResponse( content=modelResponse.content, modelName=model.name, + provider=model.connectorType, priceCHF=0.0, processingTime=processingTime, bytesSent=0, @@ -1416,6 +1418,7 @@ class ExtractionService: return AiCallResponse( content=mergedContent, modelName=model.name, + provider=model.connectorType, priceCHF=sum(r.priceCHF for r in chunkResults), processingTime=sum(r.processingTime for r in chunkResults), bytesSent=sum(r.bytesSent for r in chunkResults), @@ -1428,49 +1431,6 @@ class ExtractionService: response = await aiObjects._callWithModel(model, prompt, contentPart.data, options) logger.info(f"✅ Content part processed successfully with model: {model.name}") return response - chunks = await self.chunkContentPartForAi(contentPart, model, options, prompt) - if not chunks: - raise ValueError(f"Failed to chunk content part for model {model.name}") - - logger.info(f"Starting to process {len(chunks)} chunks with model {model.name}") - - if progressCallback: - progressCallback(0.0, f"Starting to process {len(chunks)} chunks") - - chunkResults = [] - for idx, chunk in enumerate(chunks): - chunkNum = idx + 1 - chunkData = chunk.get('data', '') - logger.info(f"Processing chunk {chunkNum}/{len(chunks)} with model {model.name}") - - if progressCallback: - progressCallback(chunkNum / len(chunks), f"Processing chunk {chunkNum}/{len(chunks)}") - - try: - chunkResponse = await aiObjects._callWithModel(model, prompt, chunkData, options) - chunkResults.append(chunkResponse) - logger.info(f"✅ Chunk {chunkNum}/{len(chunks)} processed successfully") - - if progressCallback: - progressCallback(chunkNum / len(chunks), f"Chunk {chunkNum}/{len(chunks)} processed") - except Exception as e: - logger.error(f"❌ Error processing chunk {chunkNum}/{len(chunks)}: {str(e)}") - raise - - # Merge chunk results using unified mergePartResults - # Pass original contentPart to preserve typeGroup for all chunks (one-to-many: 1 part -> N chunks) - mergedContent = self.mergePartResults(chunkResults, options, [contentPart]) - - logger.info(f"✅ Content part chunked and processed with model: {model.name} ({len(chunks)} chunks)") - return AiCallResponse( - content=mergedContent, - modelName=model.name, - priceCHF=sum(r.priceCHF for r in chunkResults), - processingTime=sum(r.processingTime for r in chunkResults), - bytesSent=sum(r.bytesSent for r in chunkResults), - bytesReceived=sum(r.bytesReceived for r in chunkResults), - errorCount=sum(r.errorCount for r in chunkResults) - ) except Exception as e: lastError = e From 1b55db45813a746a7a23cce90cad5bf496c70cf7 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Sun, 8 Feb 2026 13:28:36 +0100 Subject: [PATCH 11/18] admin views fixes --- modules/system/mainSystem.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py index c48add01..01d56eb4 100644 --- a/modules/system/mainSystem.py +++ b/modules/system/mainSystem.py @@ -184,6 +184,15 @@ NAVIGATION_SECTIONS = [ "order": 45, "adminOnly": True, }, + { + "id": "admin-feature-instances", + "objectKey": "ui.admin.featureInstances", + "label": {"en": "Feature Instances", "de": "Feature-Instanzen", "fr": "Instances de features"}, + "icon": "FaCubes", + "path": "/admin/feature-instances", + "order": 48, + "adminOnly": True, + }, { "id": "admin-feature-roles", "objectKey": "ui.admin.featureRoles", From f15ed2e380071c2e060b10eb8d25fbc324755b2f Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Sun, 8 Feb 2026 14:00:08 +0100 Subject: [PATCH 12/18] fixes --- .../automation/interfaceFeatureAutomation.py | 48 +++++---- modules/features/automation/mainAutomation.py | 3 +- .../chatplayground/mainChatplayground.py | 3 +- modules/features/realEstate/mainRealEstate.py | 3 +- modules/features/trustee/mainTrustee.py | 3 +- modules/routes/routeAdminRbacRules.py | 98 +++++++++++++++++++ 6 files changed, 135 insertions(+), 23 deletions(-) diff --git a/modules/features/automation/interfaceFeatureAutomation.py b/modules/features/automation/interfaceFeatureAutomation.py index 2bbf56e0..f88c3973 100644 --- a/modules/features/automation/interfaceFeatureAutomation.py +++ b/modules/features/automation/interfaceFeatureAutomation.py @@ -119,13 +119,12 @@ class AutomationObjects: def _enrichAutomationsWithUserAndMandate(self, automations: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Batch enrich automations with user names and mandate names for display. - Uses AppObjects interface to fetch users and mandates with proper access control. + Uses direct DB lookup (no RBAC) because this is purely cosmetic enrichment — + the user already has RBAC-verified access to the automations themselves. """ if not automations: return automations - from modules.interfaces.interfaceDbApp import getInterface as getAppInterface - # Collect all unique user IDs and mandate IDs userIds = set() mandateIds = set() @@ -139,22 +138,33 @@ class AutomationObjects: if mandateId: mandateIds.add(mandateId) - # Use AppObjects interface to fetch users (respects access control) - appInterface = getAppInterface(self.currentUser) - usersMap = {} - if userIds: - for userId in userIds: - user = appInterface.getUser(userId) - if user: - usersMap[userId] = user.username or user.email or userId - - # Use AppObjects interface to fetch mandates (respects access control) - mandatesMap = {} - if mandateIds: - for mandateId in mandateIds: - mandate = appInterface.getMandate(mandateId) - if mandate: - mandatesMap[mandateId] = mandate.name or mandateId + # Use root DB connector for display-only lookups (no RBAC needed) + try: + from modules.datamodels.datamodelUam import UserInDB, Mandate + from modules.security.rootAccess import getRootDbAppConnector + dbAppConn = getRootDbAppConnector() + + # Batch fetch user display names + usersMap = {} + if userIds: + for userId in userIds: + users = dbAppConn.getRecordset(UserInDB, {"id": userId}) + if users: + user = users[0] + fullName = f"{user.get('firstName', '')} {user.get('lastName', '')}".strip() + usersMap[userId] = fullName or user.get("email") or user.get("username") or userId + + # Batch fetch mandate display names + mandatesMap = {} + if mandateIds: + for mandateId in mandateIds: + mandates = dbAppConn.getRecordset(Mandate, {"id": mandateId}) + if mandates: + mandatesMap[mandateId] = mandates[0].get("name") or mandateId + except Exception as e: + logger.warning(f"Could not enrich automations with user/mandate names: {e}") + usersMap = {} + mandatesMap = {} # Enrich each automation with the fetched data for automation in automations: diff --git a/modules/features/automation/mainAutomation.py b/modules/features/automation/mainAutomation.py index 2b8443a9..924f5bc9 100644 --- a/modules/features/automation/mainAutomation.py +++ b/modules/features/automation/mainAutomation.py @@ -266,9 +266,10 @@ def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Di existingRules = rootInterface.getAccessRulesByRole(roleId) # Create a set of existing rule signatures to avoid duplicates + # IMPORTANT: Use .value for enum comparison, not str() which gives "AccessRuleContext.DATA" in Python 3.11+ existingSignatures = set() for rule in existingRules: - sig = (str(rule.context) if rule.context else None, rule.item) + sig = (rule.context.value if rule.context else None, rule.item) existingSignatures.add(sig) createdCount = 0 diff --git a/modules/features/chatplayground/mainChatplayground.py b/modules/features/chatplayground/mainChatplayground.py index ed0e2868..268ee467 100644 --- a/modules/features/chatplayground/mainChatplayground.py +++ b/modules/features/chatplayground/mainChatplayground.py @@ -230,9 +230,10 @@ def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Di existingRules = rootInterface.getAccessRulesByRole(roleId) # Create a set of existing rule signatures to avoid duplicates + # IMPORTANT: Use .value for enum comparison, not str() which gives "AccessRuleContext.DATA" in Python 3.11+ existingSignatures = set() for rule in existingRules: - sig = (str(rule.context) if rule.context else None, rule.item) + sig = (rule.context.value if rule.context else None, rule.item) existingSignatures.add(sig) createdCount = 0 diff --git a/modules/features/realEstate/mainRealEstate.py b/modules/features/realEstate/mainRealEstate.py index 8562f5b8..0483218d 100644 --- a/modules/features/realEstate/mainRealEstate.py +++ b/modules/features/realEstate/mainRealEstate.py @@ -245,7 +245,8 @@ def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: list) - from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext existingRules = rootInterface.getAccessRulesByRole(roleId) - existingSignatures = {(str(r.context) if r.context else None, r.item) for r in existingRules} + # IMPORTANT: Use .value for enum comparison, not str() which gives "AccessRuleContext.DATA" in Python 3.11+ + existingSignatures = {(r.context.value if r.context else None, r.item) for r in existingRules} createdCount = 0 for template in ruleTemplates or []: diff --git a/modules/features/trustee/mainTrustee.py b/modules/features/trustee/mainTrustee.py index ad449d8f..b917f9ad 100644 --- a/modules/features/trustee/mainTrustee.py +++ b/modules/features/trustee/mainTrustee.py @@ -394,9 +394,10 @@ def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Di existingRules = rootInterface.getAccessRulesByRole(roleId) # Create a set of existing rule signatures to avoid duplicates + # IMPORTANT: Use .value for enum comparison, not str() which gives "AccessRuleContext.DATA" in Python 3.11+ existingSignatures = set() for rule in existingRules: - sig = (str(rule.context) if rule.context else None, rule.item) + sig = (rule.context.value if rule.context else None, rule.item) existingSignatures.add(sig) createdCount = 0 diff --git a/modules/routes/routeAdminRbacRules.py b/modules/routes/routeAdminRbacRules.py index fc9b315e..82cc13d7 100644 --- a/modules/routes/routeAdminRbacRules.py +++ b/modules/routes/routeAdminRbacRules.py @@ -1192,3 +1192,101 @@ async def getCatalogStats( status_code=500, detail=f"Failed to get catalog stats: {str(e)}" ) + + +# ============================================================================= +# CLEANUP: Remove duplicate AccessRules +# ============================================================================= + +@router.post("/cleanup/duplicate-rules", response_model=dict) +@limiter.limit("5/minute") +async def cleanup_duplicate_access_rules( + request: Request, + dryRun: bool = Query(True, description="If true, only report duplicates without deleting"), + currentUser: User = Depends(requireSysAdmin) +) -> dict: + """ + Find and remove duplicate AccessRules. + + Duplicates are rules with the same (roleId, context, item) signature. + Only the first rule (oldest) is kept, all others are deleted. + + Query Parameters: + - dryRun: If true (default), only report what would be deleted. Set to false to actually delete. + + Returns: + - Summary with counts and details of duplicates found/removed + """ + try: + rootInterface = getRootInterface() + + # Get ALL AccessRules from DB + allRules = rootInterface.db.getRecordset(AccessRule) + + # Group by signature (roleId, context, item) + rulesBySignature: Dict[tuple, list] = {} + for rule in allRules: + context = rule.get("context", "") + # Normalize context enum value + if hasattr(context, 'value'): + context = context.value + sig = (rule.get("roleId"), str(context), rule.get("item")) + if sig not in rulesBySignature: + rulesBySignature[sig] = [] + rulesBySignature[sig].append(rule) + + # Find duplicates and collect IDs to delete + duplicateGroups = [] + idsToDelete = [] + + for sig, rules in rulesBySignature.items(): + if len(rules) > 1: + # Sort by creation time (keep oldest) + rules.sort(key=lambda r: r.get("_createdAt", 0)) + keepRule = rules[0] + deleteRules = rules[1:] + + duplicateGroups.append({ + "roleId": sig[0], + "context": sig[1], + "item": sig[2] or "(global)", + "totalCount": len(rules), + "keepId": keepRule.get("id"), + "deleteCount": len(deleteRules), + "deleteIds": [r.get("id") for r in deleteRules] + }) + + idsToDelete.extend([r.get("id") for r in deleteRules]) + + # Perform deletion if not dry run + deletedCount = 0 + if not dryRun and idsToDelete: + for ruleId in idsToDelete: + try: + rootInterface.db.recordDelete(AccessRule, ruleId) + deletedCount += 1 + except Exception as e: + logger.warning(f"Failed to delete rule {ruleId}: {e}") + + result = { + "dryRun": dryRun, + "totalRules": len(allRules), + "uniqueSignatures": len(rulesBySignature), + "duplicateGroups": len(duplicateGroups), + "duplicateRulesToDelete": len(idsToDelete), + "deletedCount": deletedCount, + "details": duplicateGroups[:50] # Limit details to 50 groups + } + + logger.info(f"AccessRule cleanup: dryRun={dryRun}, total={len(allRules)}, " + f"duplicateGroups={len(duplicateGroups)}, toDelete={len(idsToDelete)}, " + f"deleted={deletedCount}") + + return result + + except Exception as e: + logger.error(f"Error during AccessRule cleanup: {str(e)}") + raise HTTPException( + status_code=500, + detail=f"Failed to cleanup duplicate rules: {str(e)}" + ) From 8d28f6d77b3294b387c170626c5f0c3e32e811df Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Sun, 8 Feb 2026 14:26:01 +0100 Subject: [PATCH 13/18] fiixed feature instance role access --- .../automation/interfaceFeatureAutomation.py | 19 +- .../automation/routeFeatureAutomation.py | 28 +- .../features/chatbot/routeFeatureChatbot.py | 14 +- .../routeFeatureChatplayground.py | 14 +- .../neutralization/routeFeatureNeutralizer.py | 16 +- .../realEstate/routeFeatureRealEstate.py | 58 +-- .../features/trustee/routeFeatureTrustee.py | 232 +++++------ modules/routes/routeAdmin.py | 8 +- modules/routes/routeAdminAutomationEvents.py | 4 +- modules/routes/routeAdminFeatures.py | 34 +- modules/routes/routeAdminRbacExport.py | 4 +- modules/routes/routeAdminRbacRoles.py | 24 +- modules/routes/routeAdminRbacRules.py | 34 +- .../routes/routeAdminUserAccessOverview.py | 6 +- modules/routes/routeAttributes.py | 4 +- modules/routes/routeBilling.py | 32 +- modules/routes/routeDataConnections.py | 14 +- modules/routes/routeDataFiles.py | 14 +- modules/routes/routeDataMandates.py | 18 +- modules/routes/routeDataPrompts.py | 10 +- modules/routes/routeDataUsers.py | 18 +- modules/routes/routeDataWorkflows.py | 24 +- modules/routes/routeGdpr.py | 8 +- modules/routes/routeInvitations.py | 10 +- modules/routes/routeMessaging.py | 28 +- modules/routes/routeNotifications.py | 16 +- modules/routes/routeSecurityAdmin.py | 20 +- modules/routes/routeSecurityGoogle.py | 8 +- modules/routes/routeSecurityLocal.py | 16 +- modules/routes/routeSecurityMsft.py | 12 +- modules/routes/routeSystem.py | 2 +- modules/workflows/automation/mainWorkflow.py | 14 +- scripts/migrate_async_to_sync.py | 377 ++++++++++++++++++ 33 files changed, 764 insertions(+), 376 deletions(-) create mode 100644 scripts/migrate_async_to_sync.py diff --git a/modules/features/automation/interfaceFeatureAutomation.py b/modules/features/automation/interfaceFeatureAutomation.py index f88c3973..770d5eb0 100644 --- a/modules/features/automation/interfaceFeatureAutomation.py +++ b/modules/features/automation/interfaceFeatureAutomation.py @@ -88,7 +88,9 @@ class AutomationObjects: permissions = self.rbac.getUserPermissions( user=self.currentUser, context=AccessRuleContext.DATA, - item=objectKey + item=objectKey, + mandateId=self.mandateId, + featureInstanceId=self.featureInstanceId ) accessLevel = getattr(permissions, action, AccessLevel.NONE) @@ -373,6 +375,21 @@ class AutomationObjects: logger.error(f"Error creating automation definition: {str(e)}") raise + def _saveExecutionLog(self, automationId: str, executionLogs: List[Dict[str, Any]]) -> None: + """ + Save execution logs to an automation definition WITHOUT RBAC check. + + This is a system-level operation: when a user executes an automation, + the execution log must be saved regardless of whether the user has + 'update' permission on the AutomationDefinition. The user already + proved they have execute/read access by loading the automation. + """ + try: + self.db.recordModify(AutomationDefinition, automationId, {"executionLogs": executionLogs}) + logger.debug(f"Saved execution log for automation {automationId}") + except Exception as e: + logger.warning(f"Could not save execution log for automation {automationId}: {e}") + def updateAutomationDefinition(self, automationId: str, automationData: Dict[str, Any]) -> AutomationDefinition: """Updates an automation definition, then triggers sync.""" try: diff --git a/modules/features/automation/routeFeatureAutomation.py b/modules/features/automation/routeFeatureAutomation.py index d9c0b758..d39a3358 100644 --- a/modules/features/automation/routeFeatureAutomation.py +++ b/modules/features/automation/routeFeatureAutomation.py @@ -42,7 +42,7 @@ router = APIRouter( @router.get("", response_model=PaginatedResponse[AutomationDefinition]) @limiter.limit("30/minute") -async def get_automations( +def get_automations( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), context: RequestContext = Depends(getRequestContext) @@ -107,7 +107,7 @@ async def get_automations( @router.post("", response_model=AutomationDefinition) @limiter.limit("10/minute") -async def create_automation( +def create_automation( request: Request, automation: AutomationDefinition, context: RequestContext = Depends(getRequestContext) @@ -128,7 +128,7 @@ async def create_automation( ) @router.get("/attributes", response_model=Dict[str, Any]) -async def get_automation_attributes( +def get_automation_attributes( request: Request ) -> Dict[str, Any]: """Get attribute definitions for AutomationDefinition model""" @@ -137,7 +137,7 @@ async def get_automation_attributes( @router.get("/actions") @limiter.limit("30/minute") -async def get_available_actions( +def get_available_actions( request: Request, context: RequestContext = Depends(getRequestContext) ) -> JSONResponse: @@ -230,7 +230,7 @@ async def get_available_actions( @router.get("/{automationId}", response_model=AutomationDefinition) @limiter.limit("30/minute") -async def get_automation( +def get_automation( request: Request, automationId: str = Path(..., description="Automation ID"), context: RequestContext = Depends(getRequestContext) @@ -257,7 +257,7 @@ async def get_automation( @router.put("/{automationId}", response_model=AutomationDefinition) @limiter.limit("10/minute") -async def update_automation( +def update_automation( request: Request, automationId: str = Path(..., description="Automation ID"), automation: AutomationDefinition = Body(...), @@ -285,7 +285,7 @@ async def update_automation( @router.patch("/{automationId}/status") @limiter.limit("30/minute") -async def update_automation_status( +def update_automation_status( request: Request, automationId: str = Path(..., description="Automation ID"), active: bool = Body(..., embed=True), @@ -326,7 +326,7 @@ async def update_automation_status( @router.delete("/{automationId}") @limiter.limit("10/minute") -async def delete_automation( +def delete_automation( request: Request, automationId: str = Path(..., description="Automation ID"), context: RequestContext = Depends(getRequestContext) @@ -407,7 +407,7 @@ templateAttributes = getModelAttributeDefinitions(AutomationTemplate) @templateRouter.get("", response_model=PaginatedResponse[AutomationTemplate]) @limiter.limit("30/minute") -async def get_db_templates( +def get_db_templates( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), context: RequestContext = Depends(getRequestContext) @@ -470,7 +470,7 @@ async def get_db_templates( @templateRouter.get("/attributes", response_model=Dict[str, Any]) -async def get_template_attributes( +def get_template_attributes( request: Request ) -> Dict[str, Any]: """Get attribute definitions for AutomationTemplate model""" @@ -479,7 +479,7 @@ async def get_template_attributes( @templateRouter.get("/{templateId}") @limiter.limit("30/minute") -async def get_db_template( +def get_db_template( request: Request, templateId: str = Path(..., description="Template ID"), context: RequestContext = Depends(getRequestContext) @@ -511,7 +511,7 @@ async def get_db_template( @templateRouter.post("") @limiter.limit("10/minute") -async def create_db_template( +def create_db_template( request: Request, templateData: Dict[str, Any] = Body(...), context: RequestContext = Depends(getRequestContext) @@ -542,7 +542,7 @@ async def create_db_template( @templateRouter.put("/{templateId}") @limiter.limit("10/minute") -async def update_db_template( +def update_db_template( request: Request, templateId: str = Path(..., description="Template ID"), templateData: Dict[str, Any] = Body(...), @@ -574,7 +574,7 @@ async def update_db_template( @templateRouter.delete("/{templateId}") @limiter.limit("10/minute") -async def delete_db_template( +def delete_db_template( request: Request, templateId: str = Path(..., description="Template ID"), context: RequestContext = Depends(getRequestContext) diff --git a/modules/features/chatbot/routeFeatureChatbot.py b/modules/features/chatbot/routeFeatureChatbot.py index 62dc02f9..d8b4dd70 100644 --- a/modules/features/chatbot/routeFeatureChatbot.py +++ b/modules/features/chatbot/routeFeatureChatbot.py @@ -55,7 +55,7 @@ def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None): ) -async def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: +def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: """ Validate that the user has access to the feature instance. Returns the mandateId for the instance. @@ -124,7 +124,7 @@ async def stream_chatbot_start( - Query parameter takes precedence if both are provided """ # Validate instance access - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) event_manager = get_event_manager() @@ -323,7 +323,7 @@ async def stop_chatbot( ) -> ChatWorkflow: """Stops a running chatbot workflow.""" # Validate instance access - await _validateInstanceAccess(instanceId, context) + _validateInstanceAccess(instanceId, context) try: # Get chatbot interface with instance context @@ -392,7 +392,7 @@ async def stop_chatbot( # to prevent "threads" from being matched as a workflowId @router.get("/{instanceId}/threads") @limiter.limit("120/minute") -async def get_chatbot_threads( +def get_chatbot_threads( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), workflowId: Optional[str] = Query(None, description="Optional workflow ID to get details and chat data for a specific thread"), @@ -406,7 +406,7 @@ async def get_chatbot_threads( - If workflowId is not provided: Returns a paginated list of all workflows """ # Validate instance access - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) try: interfaceDbChat = _getServiceChat(context, instanceId) @@ -523,7 +523,7 @@ async def get_chatbot_threads( # NOTE: This catch-all route MUST be defined AFTER more specific routes like /threads @router.delete("/{instanceId}/{workflowId}", response_model=Dict[str, Any]) @limiter.limit("120/minute") -async def delete_chatbot( +def delete_chatbot( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), workflowId: str = Path(..., description="ID of the workflow to delete"), @@ -531,7 +531,7 @@ async def delete_chatbot( ) -> Dict[str, Any]: """Deletes a chatbot workflow and its associated data.""" # Validate instance access - if user has access to instance, they can delete their workflows - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) try: # Get service center diff --git a/modules/features/chatplayground/routeFeatureChatplayground.py b/modules/features/chatplayground/routeFeatureChatplayground.py index cedd05d6..ce1611ea 100644 --- a/modules/features/chatplayground/routeFeatureChatplayground.py +++ b/modules/features/chatplayground/routeFeatureChatplayground.py @@ -41,7 +41,7 @@ def _getServiceChat(context: RequestContext, featureInstanceId: str = None): ) -async def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: +def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: """ Validate that user has access to the feature instance. @@ -93,7 +93,7 @@ async def start_workflow( """ try: # Validate access and get mandate ID - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) # Start or continue workflow workflow = await chatStart( @@ -129,7 +129,7 @@ async def stop_workflow( """Stops a running workflow.""" try: # Validate access and get mandate ID - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) # Stop workflow (pass featureInstanceId for proper RBAC filtering) workflow = await chatStop( @@ -154,7 +154,7 @@ async def stop_workflow( # Unified Chat Data Endpoint for Polling @router.get("/{instanceId}/{workflowId}/chatData") @limiter.limit("120/minute") -async def get_workflow_chat_data( +def get_workflow_chat_data( request: Request, instanceId: str = Path(..., description="Feature instance ID"), workflowId: str = Path(..., description="ID of the workflow"), @@ -167,7 +167,7 @@ async def get_workflow_chat_data( """ try: # Validate access - await _validateInstanceAccess(instanceId, context) + _validateInstanceAccess(instanceId, context) # Get service with feature instance context chatInterface = _getServiceChat(context, featureInstanceId=instanceId) @@ -198,7 +198,7 @@ async def get_workflow_chat_data( # Get workflows for this instance @router.get("/{instanceId}/workflows") @limiter.limit("120/minute") -async def get_workflows( +def get_workflows( request: Request, instanceId: str = Path(..., description="Feature instance ID"), page: int = Query(1, ge=1, description="Page number"), @@ -210,7 +210,7 @@ async def get_workflows( """ try: # Validate access - await _validateInstanceAccess(instanceId, context) + _validateInstanceAccess(instanceId, context) # Get service with feature instance context chatInterface = _getServiceChat(context, featureInstanceId=instanceId) diff --git a/modules/features/neutralization/routeFeatureNeutralizer.py b/modules/features/neutralization/routeFeatureNeutralizer.py index be262e47..33b9a00d 100644 --- a/modules/features/neutralization/routeFeatureNeutralizer.py +++ b/modules/features/neutralization/routeFeatureNeutralizer.py @@ -29,7 +29,7 @@ router = APIRouter( @router.get("/config", response_model=DataNeutraliserConfig) @limiter.limit("30/minute") -async def get_neutralization_config( +def get_neutralization_config( request: Request, context: RequestContext = Depends(getRequestContext) ) -> DataNeutraliserConfig: @@ -62,7 +62,7 @@ async def get_neutralization_config( @router.post("/config", response_model=DataNeutraliserConfig) @limiter.limit("10/minute") -async def save_neutralization_config( +def save_neutralization_config( request: Request, config_data: Dict[str, Any] = Body(...), context: RequestContext = Depends(getRequestContext) @@ -83,7 +83,7 @@ async def save_neutralization_config( @router.post("/neutralize-text", response_model=Dict[str, Any]) @limiter.limit("20/minute") -async def neutralize_text( +def neutralize_text( request: Request, text_data: Dict[str, Any] = Body(...), context: RequestContext = Depends(getRequestContext) @@ -115,7 +115,7 @@ async def neutralize_text( @router.post("/resolve-text", response_model=Dict[str, str]) @limiter.limit("20/minute") -async def resolve_text( +def resolve_text( request: Request, text_data: Dict[str, str] = Body(...), context: RequestContext = Depends(getRequestContext) @@ -146,7 +146,7 @@ async def resolve_text( @router.get("/attributes", response_model=List[DataNeutralizerAttributes]) @limiter.limit("30/minute") -async def get_neutralization_attributes( +def get_neutralization_attributes( request: Request, fileId: Optional[str] = Query(None, description="Filter by file ID"), context: RequestContext = Depends(getRequestContext) @@ -199,7 +199,7 @@ async def process_sharepoint_files( @router.post("/batch-process", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def batch_process_files( +def batch_process_files( request: Request, files_data: List[Dict[str, Any]] = Body(...), context: RequestContext = Depends(getRequestContext) @@ -228,7 +228,7 @@ async def batch_process_files( @router.get("/stats", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def get_neutralization_stats( +def get_neutralization_stats( request: Request, context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: @@ -248,7 +248,7 @@ async def get_neutralization_stats( @router.delete("/attributes/{fileId}", response_model=Dict[str, str]) @limiter.limit("10/minute") -async def cleanup_file_attributes( +def cleanup_file_attributes( request: Request, fileId: str = Path(..., description="File ID to cleanup attributes for"), context: RequestContext = Depends(getRequestContext) diff --git a/modules/features/realEstate/routeFeatureRealEstate.py b/modules/features/realEstate/routeFeatureRealEstate.py index 09f28f13..18fecd73 100644 --- a/modules/features/realEstate/routeFeatureRealEstate.py +++ b/modules/features/realEstate/routeFeatureRealEstate.py @@ -83,7 +83,7 @@ def _parsePagination(pagination: Optional[str]) -> Optional[PaginationParams]: return None -async def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: +def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: """ Validate that the user has access to the feature instance. Returns the mandateId for the instance. @@ -132,14 +132,14 @@ _REALESTATE_ENTITY_MODELS = { @router.get("/{instanceId}/attributes/{entityType}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def get_entity_attributes( +def get_entity_attributes( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), entityType: str = Path(..., description="Entity type (e.g., Projekt, Parzelle)"), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Get attribute definitions for a Real Estate entity. Used by FormGeneratorTable.""" - await _validateInstanceAccess(instanceId, context) + _validateInstanceAccess(instanceId, context) if entityType not in _REALESTATE_ENTITY_MODELS: raise HTTPException( status_code=404, @@ -163,13 +163,13 @@ async def get_entity_attributes( @router.get("/{instanceId}/projects/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_project_options( +def get_project_options( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), context: RequestContext = Depends(getRequestContext) ) -> List[Dict[str, Any]]: """Get project options for select dropdowns. Returns: [{ value, label }]""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -179,13 +179,13 @@ async def get_project_options( @router.get("/{instanceId}/parcels/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_parcel_options( +def get_parcel_options( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), context: RequestContext = Depends(getRequestContext) ) -> List[Dict[str, Any]]: """Get parcel options for select dropdowns. Returns: [{ value, label }]""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -197,14 +197,14 @@ async def get_parcel_options( @router.get("/{instanceId}/projects", response_model=PaginatedResponse[Projekt]) @limiter.limit("30/minute") -async def get_projects( +def get_projects( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"), context: RequestContext = Depends(getRequestContext) ) -> PaginatedResponse[Projekt]: """Get all projects for a feature instance with optional pagination.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -241,14 +241,14 @@ async def get_projects( @router.get("/{instanceId}/projects/{projectId}", response_model=Projekt) @limiter.limit("30/minute") -async def get_project_by_id( +def get_project_by_id( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), projectId: str = Path(..., description="Project ID"), context: RequestContext = Depends(getRequestContext) ) -> Projekt: """Get a single project by ID.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -260,14 +260,14 @@ async def get_project_by_id( @router.post("/{instanceId}/projects", response_model=Projekt) @limiter.limit("30/minute") -async def create_project( +def create_project( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), data: Dict[str, Any] = Body(...), context: RequestContext = Depends(getRequestContext) ) -> Projekt: """Create a new project.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -284,7 +284,7 @@ async def create_project( @router.put("/{instanceId}/projects/{projectId}", response_model=Projekt) @limiter.limit("30/minute") -async def update_project( +def update_project( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), projectId: str = Path(..., description="Project ID"), @@ -292,7 +292,7 @@ async def update_project( context: RequestContext = Depends(getRequestContext) ) -> Projekt: """Update a project.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -307,14 +307,14 @@ async def update_project( @router.delete("/{instanceId}/projects/{projectId}", status_code=status.HTTP_204_NO_CONTENT) @limiter.limit("30/minute") -async def delete_project( +def delete_project( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), projectId: str = Path(..., description="Project ID"), context: RequestContext = Depends(getRequestContext) ) -> None: """Delete a project.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -329,14 +329,14 @@ async def delete_project( @router.get("/{instanceId}/parcels", response_model=PaginatedResponse[Parzelle]) @limiter.limit("30/minute") -async def get_parcels( +def get_parcels( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"), context: RequestContext = Depends(getRequestContext) ) -> PaginatedResponse[Parzelle]: """Get all parcels for a feature instance with optional pagination.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -373,14 +373,14 @@ async def get_parcels( @router.get("/{instanceId}/parcels/{parcelId}", response_model=Parzelle) @limiter.limit("30/minute") -async def get_parcel_by_id( +def get_parcel_by_id( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), parcelId: str = Path(..., description="Parcel ID"), context: RequestContext = Depends(getRequestContext) ) -> Parzelle: """Get a single parcel by ID.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -392,14 +392,14 @@ async def get_parcel_by_id( @router.post("/{instanceId}/parcels", response_model=Parzelle) @limiter.limit("30/minute") -async def create_parcel( +def create_parcel( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), data: Dict[str, Any] = Body(...), context: RequestContext = Depends(getRequestContext) ) -> Parzelle: """Create a new parcel.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -416,7 +416,7 @@ async def create_parcel( @router.put("/{instanceId}/parcels/{parcelId}", response_model=Parzelle) @limiter.limit("30/minute") -async def update_parcel( +def update_parcel( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), parcelId: str = Path(..., description="Parcel ID"), @@ -424,7 +424,7 @@ async def update_parcel( context: RequestContext = Depends(getRequestContext) ) -> Parzelle: """Update a parcel.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -439,14 +439,14 @@ async def update_parcel( @router.delete("/{instanceId}/parcels/{parcelId}", status_code=status.HTTP_204_NO_CONTENT) @limiter.limit("30/minute") -async def delete_parcel( +def delete_parcel( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), parcelId: str = Path(..., description="Parcel ID"), context: RequestContext = Depends(getRequestContext) ) -> None: """Delete a parcel.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getRealEstateInterface( context.user, mandateId=mandateId, featureInstanceId=instanceId ) @@ -549,7 +549,7 @@ async def process_command( @router.get("/tables", response_model=Dict[str, Any]) @limiter.limit("120/minute") -async def get_available_tables( +def get_available_tables( request: Request, context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: @@ -645,7 +645,7 @@ async def get_available_tables( @router.get("/table/{table}", response_model=PaginatedResponse[Any]) @limiter.limit("120/minute") -async def get_table_data( +def get_table_data( request: Request, table: str = Path(..., description="Table name (Projekt, Parzelle, Dokument, Gemeinde, Kanton, Land)"), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), diff --git a/modules/features/trustee/routeFeatureTrustee.py b/modules/features/trustee/routeFeatureTrustee.py index 43706a10..408871c5 100644 --- a/modules/features/trustee/routeFeatureTrustee.py +++ b/modules/features/trustee/routeFeatureTrustee.py @@ -66,7 +66,7 @@ def _parsePagination(pagination: Optional[str]) -> Optional[PaginationParams]: return None -async def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: +def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: """ Validate that the user has access to the feature instance. Returns the mandateId for the instance. @@ -134,7 +134,7 @@ _TRUSTEE_ENTITY_MODELS = { @router.get("/{instanceId}/attributes/{entityType}") @limiter.limit("30/minute") -async def get_entity_attributes( +def get_entity_attributes( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), entityType: str = Path(..., description="Entity type (e.g., TrusteeDocument)"), @@ -145,7 +145,7 @@ async def get_entity_attributes( Used by FormGeneratorTable for dynamic column generation. """ # Validate instance access - await _validateInstanceAccess(instanceId, context) + _validateInstanceAccess(instanceId, context) # Check if entity type is valid if entityType not in _TRUSTEE_ENTITY_MODELS: @@ -182,7 +182,7 @@ async def get_entity_attributes( @router.get("/mime-types/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_mime_type_options( +def get_mime_type_options( request: Request, context: RequestContext = Depends(getRequestContext) ) -> List[Dict[str, Any]]: @@ -217,13 +217,13 @@ async def get_mime_type_options( @router.get("/{instanceId}/organisations/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_organisation_options( +def get_organisation_options( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), context: RequestContext = Depends(getRequestContext) ) -> List[Dict[str, Any]]: """Get organisation options for select dropdowns. Returns: [{ value, label }]""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllOrganisations(None) items = result.items if hasattr(result, 'items') else result @@ -232,13 +232,13 @@ async def get_organisation_options( @router.get("/{instanceId}/roles/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_role_options( +def get_role_options( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), context: RequestContext = Depends(getRequestContext) ) -> List[Dict[str, Any]]: """Get role options for select dropdowns. Returns: [{ value, label }]""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllRoles(None) items = result.items if hasattr(result, 'items') else result @@ -247,7 +247,7 @@ async def get_role_options( @router.get("/{instanceId}/contracts/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_contract_options( +def get_contract_options( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), organisationId: Optional[str] = Query(None, description="Optional: Filter by organisation ID"), @@ -261,7 +261,7 @@ async def get_contract_options( Returns: [{ value, label }] """ - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) if organisationId: @@ -277,13 +277,13 @@ async def get_contract_options( @router.get("/{instanceId}/documents/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_document_options( +def get_document_options( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), context: RequestContext = Depends(getRequestContext) ) -> List[Dict[str, Any]]: """Get document options for select dropdowns. Returns: [{ id, value, label }]""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllDocuments(None) items = result.items if hasattr(result, 'items') else result @@ -293,13 +293,13 @@ async def get_document_options( @router.get("/{instanceId}/positions/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_position_options( +def get_position_options( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), context: RequestContext = Depends(getRequestContext) ) -> List[Dict[str, Any]]: """Get position options for select dropdowns. Returns: [{ id, value, label }]""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllPositions(None) items = result.items if hasattr(result, 'items') else result @@ -326,14 +326,14 @@ async def get_position_options( @router.get("/{instanceId}/organisations", response_model=PaginatedResponse[TrusteeOrganisation]) @limiter.limit("30/minute") -async def get_organisations( +def get_organisations( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"), context: RequestContext = Depends(getRequestContext) ) -> PaginatedResponse[TrusteeOrganisation]: """Get all organisations for a feature instance with optional pagination.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) paginationParams = _parsePagination(pagination) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) @@ -356,14 +356,14 @@ async def get_organisations( @router.get("/{instanceId}/organisations/{orgId}", response_model=TrusteeOrganisation) @limiter.limit("30/minute") -async def get_organisation( +def get_organisation( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), orgId: str = Path(..., description="Organisation ID"), context: RequestContext = Depends(getRequestContext) ) -> TrusteeOrganisation: """Get a single organisation by ID.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) org = interface.getOrganisation(orgId) @@ -374,14 +374,14 @@ async def get_organisation( @router.post("/{instanceId}/organisations", response_model=TrusteeOrganisation, status_code=201) @limiter.limit("10/minute") -async def create_organisation( +def create_organisation( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), data: TrusteeOrganisation = Body(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteeOrganisation: """Create a new organisation.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.createOrganisation(data.model_dump()) @@ -392,7 +392,7 @@ async def create_organisation( @router.put("/{instanceId}/organisations/{orgId}", response_model=TrusteeOrganisation) @limiter.limit("10/minute") -async def update_organisation( +def update_organisation( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), orgId: str = Path(..., description="Organisation ID"), @@ -400,7 +400,7 @@ async def update_organisation( context: RequestContext = Depends(getRequestContext) ) -> TrusteeOrganisation: """Update an organisation.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getOrganisation(orgId) @@ -415,14 +415,14 @@ async def update_organisation( @router.delete("/{instanceId}/organisations/{orgId}") @limiter.limit("10/minute") -async def delete_organisation( +def delete_organisation( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), orgId: str = Path(..., description="Organisation ID"), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Delete an organisation.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getOrganisation(orgId) @@ -439,14 +439,14 @@ async def delete_organisation( @router.get("/{instanceId}/roles", response_model=PaginatedResponse[TrusteeRole]) @limiter.limit("30/minute") -async def get_roles( +def get_roles( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), pagination: Optional[str] = Query(None), context: RequestContext = Depends(getRequestContext) ) -> PaginatedResponse[TrusteeRole]: """Get all roles with optional pagination.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) paginationParams = _parsePagination(pagination) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) @@ -469,14 +469,14 @@ async def get_roles( @router.get("/{instanceId}/roles/{roleId}", response_model=TrusteeRole) @limiter.limit("30/minute") -async def get_role( +def get_role( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), roleId: str = Path(..., description="Role ID"), context: RequestContext = Depends(getRequestContext) ) -> TrusteeRole: """Get a single role by ID.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) role = interface.getRole(roleId) @@ -487,14 +487,14 @@ async def get_role( @router.post("/{instanceId}/roles", response_model=TrusteeRole, status_code=201) @limiter.limit("10/minute") -async def create_role( +def create_role( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), data: TrusteeRole = Body(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteeRole: """Create a new role (sysadmin only).""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.createRole(data.model_dump()) @@ -505,7 +505,7 @@ async def create_role( @router.put("/{instanceId}/roles/{roleId}", response_model=TrusteeRole) @limiter.limit("10/minute") -async def update_role( +def update_role( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), roleId: str = Path(...), @@ -513,7 +513,7 @@ async def update_role( context: RequestContext = Depends(getRequestContext) ) -> TrusteeRole: """Update a role (sysadmin only).""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getRole(roleId) @@ -528,14 +528,14 @@ async def update_role( @router.delete("/{instanceId}/roles/{roleId}") @limiter.limit("10/minute") -async def delete_role( +def delete_role( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), roleId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Delete a role (sysadmin only, fails if in use).""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getRole(roleId) @@ -552,14 +552,14 @@ async def delete_role( @router.get("/{instanceId}/access", response_model=PaginatedResponse[TrusteeAccess]) @limiter.limit("30/minute") -async def get_all_access( +def get_all_access( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), pagination: Optional[str] = Query(None), context: RequestContext = Depends(getRequestContext) ) -> PaginatedResponse[TrusteeAccess]: """Get all access records with optional pagination.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) paginationParams = _parsePagination(pagination) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) @@ -582,14 +582,14 @@ async def get_all_access( @router.get("/{instanceId}/access/{accessId}", response_model=TrusteeAccess) @limiter.limit("30/minute") -async def get_access( +def get_access( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), accessId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteeAccess: """Get a single access record by ID.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) access = interface.getAccess(accessId) @@ -600,14 +600,14 @@ async def get_access( @router.get("/{instanceId}/access/organisation/{orgId}", response_model=List[TrusteeAccess]) @limiter.limit("30/minute") -async def get_access_by_organisation( +def get_access_by_organisation( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), orgId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> List[TrusteeAccess]: """Get all access records for an organisation.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) return interface.getAccessByOrganisation(orgId) @@ -615,14 +615,14 @@ async def get_access_by_organisation( @router.get("/{instanceId}/access/user/{userId}", response_model=List[TrusteeAccess]) @limiter.limit("30/minute") -async def get_access_by_user( +def get_access_by_user( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), userId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> List[TrusteeAccess]: """Get all access records for a user.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) return interface.getAccessByUser(userId) @@ -630,14 +630,14 @@ async def get_access_by_user( @router.post("/{instanceId}/access", response_model=TrusteeAccess, status_code=201) @limiter.limit("10/minute") -async def create_access( +def create_access( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), data: TrusteeAccess = Body(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteeAccess: """Create a new access record.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.createAccess(data.model_dump()) @@ -648,7 +648,7 @@ async def create_access( @router.put("/{instanceId}/access/{accessId}", response_model=TrusteeAccess) @limiter.limit("10/minute") -async def update_access( +def update_access( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), accessId: str = Path(...), @@ -656,7 +656,7 @@ async def update_access( context: RequestContext = Depends(getRequestContext) ) -> TrusteeAccess: """Update an access record.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getAccess(accessId) @@ -671,14 +671,14 @@ async def update_access( @router.delete("/{instanceId}/access/{accessId}") @limiter.limit("10/minute") -async def delete_access( +def delete_access( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), accessId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Delete an access record.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getAccess(accessId) @@ -695,14 +695,14 @@ async def delete_access( @router.get("/{instanceId}/contracts", response_model=PaginatedResponse[TrusteeContract]) @limiter.limit("30/minute") -async def get_contracts( +def get_contracts( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), pagination: Optional[str] = Query(None), context: RequestContext = Depends(getRequestContext) ) -> PaginatedResponse[TrusteeContract]: """Get all contracts with optional pagination.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) paginationParams = _parsePagination(pagination) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) @@ -725,14 +725,14 @@ async def get_contracts( @router.get("/{instanceId}/contracts/{contractId}", response_model=TrusteeContract) @limiter.limit("30/minute") -async def get_contract( +def get_contract( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), contractId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteeContract: """Get a single contract by ID.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) contract = interface.getContract(contractId) @@ -743,14 +743,14 @@ async def get_contract( @router.get("/{instanceId}/contracts/organisation/{orgId}", response_model=List[TrusteeContract]) @limiter.limit("30/minute") -async def get_contracts_by_organisation( +def get_contracts_by_organisation( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), orgId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> List[TrusteeContract]: """Get all contracts for an organisation.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) return interface.getContractsByOrganisation(orgId) @@ -758,14 +758,14 @@ async def get_contracts_by_organisation( @router.post("/{instanceId}/contracts", response_model=TrusteeContract, status_code=201) @limiter.limit("10/minute") -async def create_contract( +def create_contract( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), data: TrusteeContract = Body(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteeContract: """Create a new contract.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.createContract(data.model_dump()) @@ -776,7 +776,7 @@ async def create_contract( @router.put("/{instanceId}/contracts/{contractId}", response_model=TrusteeContract) @limiter.limit("10/minute") -async def update_contract( +def update_contract( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), contractId: str = Path(...), @@ -784,7 +784,7 @@ async def update_contract( context: RequestContext = Depends(getRequestContext) ) -> TrusteeContract: """Update a contract (organisationId is immutable).""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getContract(contractId) @@ -799,14 +799,14 @@ async def update_contract( @router.delete("/{instanceId}/contracts/{contractId}") @limiter.limit("10/minute") -async def delete_contract( +def delete_contract( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), contractId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Delete a contract.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getContract(contractId) @@ -823,14 +823,14 @@ async def delete_contract( @router.get("/{instanceId}/documents", response_model=PaginatedResponse[TrusteeDocument]) @limiter.limit("30/minute") -async def get_documents( +def get_documents( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), pagination: Optional[str] = Query(None), context: RequestContext = Depends(getRequestContext) ) -> PaginatedResponse[TrusteeDocument]: """Get all documents (metadata only) with optional pagination.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) paginationParams = _parsePagination(pagination) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) @@ -853,14 +853,14 @@ async def get_documents( @router.get("/{instanceId}/documents/{documentId}", response_model=TrusteeDocument) @limiter.limit("30/minute") -async def get_document( +def get_document( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), documentId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteeDocument: """Get document metadata by ID.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) doc = interface.getDocument(documentId) @@ -871,14 +871,14 @@ async def get_document( @router.get("/{instanceId}/documents/{documentId}/data") @limiter.limit("10/minute") -async def get_document_data( +def get_document_data( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), documentId: str = Path(...), context: RequestContext = Depends(getRequestContext) ): """Download document binary data.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) doc = interface.getDocument(documentId) @@ -898,14 +898,14 @@ async def get_document_data( @router.get("/{instanceId}/documents/contract/{contractId}", response_model=List[TrusteeDocument]) @limiter.limit("30/minute") -async def get_documents_by_contract( +def get_documents_by_contract( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), contractId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> List[TrusteeDocument]: """Get all documents for a contract.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) return interface.getDocumentsByContract(contractId) @@ -919,7 +919,7 @@ async def create_document( context: RequestContext = Depends(getRequestContext) ) -> TrusteeDocument: """Create a new document. Accepts JSON body with optional base64-encoded documentData.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) # Parse JSON body body = await request.json() @@ -959,7 +959,7 @@ async def upload_document( context: RequestContext = Depends(getRequestContext) ) -> TrusteeDocument: """Upload a document with multipart/form-data.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) # Read file content fileContent = await file.read() @@ -980,7 +980,7 @@ async def upload_document( @router.put("/{instanceId}/documents/{documentId}", response_model=TrusteeDocument) @limiter.limit("10/minute") -async def update_document( +def update_document( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), documentId: str = Path(...), @@ -988,7 +988,7 @@ async def update_document( context: RequestContext = Depends(getRequestContext) ) -> TrusteeDocument: """Update document metadata.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getDocument(documentId) @@ -1003,14 +1003,14 @@ async def update_document( @router.delete("/{instanceId}/documents/{documentId}") @limiter.limit("10/minute") -async def delete_document( +def delete_document( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), documentId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Delete a document.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getDocument(documentId) @@ -1027,14 +1027,14 @@ async def delete_document( @router.get("/{instanceId}/positions", response_model=PaginatedResponse[TrusteePosition]) @limiter.limit("30/minute") -async def get_positions( +def get_positions( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), pagination: Optional[str] = Query(None), context: RequestContext = Depends(getRequestContext) ) -> PaginatedResponse[TrusteePosition]: """Get all positions with optional pagination.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) paginationParams = _parsePagination(pagination) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) @@ -1057,14 +1057,14 @@ async def get_positions( @router.get("/{instanceId}/positions/{positionId}", response_model=TrusteePosition) @limiter.limit("30/minute") -async def get_position( +def get_position( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), positionId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteePosition: """Get a single position by ID.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) position = interface.getPosition(positionId) @@ -1075,14 +1075,14 @@ async def get_position( @router.get("/{instanceId}/positions/contract/{contractId}", response_model=List[TrusteePosition]) @limiter.limit("30/minute") -async def get_positions_by_contract( +def get_positions_by_contract( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), contractId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> List[TrusteePosition]: """Get all positions for a contract.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) return interface.getPositionsByContract(contractId) @@ -1090,14 +1090,14 @@ async def get_positions_by_contract( @router.get("/{instanceId}/positions/organisation/{orgId}", response_model=List[TrusteePosition]) @limiter.limit("30/minute") -async def get_positions_by_organisation( +def get_positions_by_organisation( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), orgId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> List[TrusteePosition]: """Get all positions for an organisation.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) return interface.getPositionsByOrganisation(orgId) @@ -1105,14 +1105,14 @@ async def get_positions_by_organisation( @router.post("/{instanceId}/positions", response_model=TrusteePosition, status_code=201) @limiter.limit("10/minute") -async def create_position( +def create_position( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), data: TrusteePosition = Body(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteePosition: """Create a new position.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.createPosition(data.model_dump()) @@ -1123,7 +1123,7 @@ async def create_position( @router.put("/{instanceId}/positions/{positionId}", response_model=TrusteePosition) @limiter.limit("10/minute") -async def update_position( +def update_position( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), positionId: str = Path(...), @@ -1131,7 +1131,7 @@ async def update_position( context: RequestContext = Depends(getRequestContext) ) -> TrusteePosition: """Update a position.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getPosition(positionId) @@ -1146,14 +1146,14 @@ async def update_position( @router.delete("/{instanceId}/positions/{positionId}") @limiter.limit("10/minute") -async def delete_position( +def delete_position( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), positionId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Delete a position.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getPosition(positionId) @@ -1170,7 +1170,7 @@ async def delete_position( @router.get("/{instanceId}/position-documents") @limiter.limit("30/minute") -async def get_position_documents( +def get_position_documents( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), pagination: Optional[str] = Query(None), @@ -1180,7 +1180,7 @@ async def get_position_documents( Each item includes _permissions: { canUpdate, canDelete } for row-level permission UI. """ - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) paginationParams = _parsePagination(pagination) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) @@ -1203,14 +1203,14 @@ async def get_position_documents( @router.get("/{instanceId}/position-documents/{linkId}", response_model=TrusteePositionDocument) @limiter.limit("30/minute") -async def get_position_document( +def get_position_document( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), linkId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteePositionDocument: """Get a single position-document link by ID.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) link = interface.getPositionDocument(linkId) @@ -1221,14 +1221,14 @@ async def get_position_document( @router.get("/{instanceId}/position-documents/position/{positionId}", response_model=List[TrusteePositionDocument]) @limiter.limit("30/minute") -async def get_documents_for_position( +def get_documents_for_position( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), positionId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> List[TrusteePositionDocument]: """Get all document links for a position.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) return interface.getDocumentsForPosition(positionId) @@ -1236,14 +1236,14 @@ async def get_documents_for_position( @router.get("/{instanceId}/position-documents/document/{documentId}", response_model=List[TrusteePositionDocument]) @limiter.limit("30/minute") -async def get_positions_for_document( +def get_positions_for_document( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), documentId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> List[TrusteePositionDocument]: """Get all position links for a document.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) return interface.getPositionsForDocument(documentId) @@ -1251,14 +1251,14 @@ async def get_positions_for_document( @router.post("/{instanceId}/position-documents", response_model=TrusteePositionDocument, status_code=201) @limiter.limit("10/minute") -async def create_position_document( +def create_position_document( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), data: TrusteePositionDocument = Body(...), context: RequestContext = Depends(getRequestContext) ) -> TrusteePositionDocument: """Create a new position-document link.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.createPositionDocument(data.model_dump()) @@ -1269,7 +1269,7 @@ async def create_position_document( @router.put("/{instanceId}/position-documents/{linkId}", response_model=TrusteePositionDocument) @limiter.limit("10/minute") -async def update_position_document( +def update_position_document( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), linkId: str = Path(...), @@ -1277,7 +1277,7 @@ async def update_position_document( context: RequestContext = Depends(getRequestContext) ) -> TrusteePositionDocument: """Update a position-document link.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.updatePositionDocument(linkId, data.model_dump(exclude_unset=True)) @@ -1288,14 +1288,14 @@ async def update_position_document( @router.delete("/{instanceId}/position-documents/{linkId}") @limiter.limit("10/minute") -async def delete_position_document( +def delete_position_document( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), linkId: str = Path(...), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Delete a position-document link.""" - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) existing = interface.getPositionDocument(linkId) @@ -1314,14 +1314,14 @@ async def delete_position_document( from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext -async def _validateInstanceAdmin(instanceId: str, context: RequestContext) -> str: +def _validateInstanceAdmin(instanceId: str, context: RequestContext) -> str: """ Validate that the user has admin access to the feature instance. Returns the mandateId if authorized. This checks for the RESOURCE permission 'instance-roles.manage'. """ - mandateId = await _validateInstanceAccess(instanceId, context) + mandateId = _validateInstanceAccess(instanceId, context) # SysAdmin always has access if context.user.isSysAdmin: @@ -1350,7 +1350,7 @@ async def _validateInstanceAdmin(instanceId: str, context: RequestContext) -> st @router.get("/{instanceId}/instance-roles", response_model=PaginatedResponse) @limiter.limit("30/minute") -async def get_instance_roles( +def get_instance_roles( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), context: RequestContext = Depends(getRequestContext) @@ -1359,7 +1359,7 @@ async def get_instance_roles( Get all roles for this feature instance. Requires feature admin permission. """ - mandateId = await _validateInstanceAdmin(instanceId, context) + mandateId = _validateInstanceAdmin(instanceId, context) rootInterface = getRootInterface() @@ -1374,14 +1374,14 @@ async def get_instance_roles( @router.get("/{instanceId}/instance-roles/{roleId}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def get_instance_role( +def get_instance_role( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), roleId: str = Path(..., description="Role ID"), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Get a specific instance role.""" - mandateId = await _validateInstanceAdmin(instanceId, context) + mandateId = _validateInstanceAdmin(instanceId, context) rootInterface = getRootInterface() role = rootInterface.getRole(roleId) @@ -1398,7 +1398,7 @@ async def get_instance_role( @router.get("/{instanceId}/instance-roles/{roleId}/rules", response_model=PaginatedResponse) @limiter.limit("30/minute") -async def get_instance_role_rules( +def get_instance_role_rules( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), roleId: str = Path(..., description="Role ID"), @@ -1408,7 +1408,7 @@ async def get_instance_role_rules( Get all AccessRules for a specific instance role. Requires feature admin permission. """ - mandateId = await _validateInstanceAdmin(instanceId, context) + mandateId = _validateInstanceAdmin(instanceId, context) rootInterface = getRootInterface() @@ -1428,7 +1428,7 @@ async def get_instance_role_rules( @router.post("/{instanceId}/instance-roles/{roleId}/rules", response_model=Dict[str, Any], status_code=201) @limiter.limit("10/minute") -async def create_instance_role_rule( +def create_instance_role_rule( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), roleId: str = Path(..., description="Role ID"), @@ -1439,7 +1439,7 @@ async def create_instance_role_rule( Create a new AccessRule for an instance role. Requires feature admin permission. """ - mandateId = await _validateInstanceAdmin(instanceId, context) + mandateId = _validateInstanceAdmin(instanceId, context) rootInterface = getRootInterface() @@ -1477,7 +1477,7 @@ async def create_instance_role_rule( @router.put("/{instanceId}/instance-roles/{roleId}/rules/{ruleId}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def update_instance_role_rule( +def update_instance_role_rule( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), roleId: str = Path(..., description="Role ID"), @@ -1490,7 +1490,7 @@ async def update_instance_role_rule( Only view, read, create, update, delete can be changed. Requires feature admin permission. """ - mandateId = await _validateInstanceAdmin(instanceId, context) + mandateId = _validateInstanceAdmin(instanceId, context) rootInterface = getRootInterface() @@ -1530,7 +1530,7 @@ async def update_instance_role_rule( @router.delete("/{instanceId}/instance-roles/{roleId}/rules/{ruleId}") @limiter.limit("10/minute") -async def delete_instance_role_rule( +def delete_instance_role_rule( request: Request, instanceId: str = Path(..., description="Feature Instance ID"), roleId: str = Path(..., description="Role ID"), @@ -1541,7 +1541,7 @@ async def delete_instance_role_rule( Delete an AccessRule for an instance role. Requires feature admin permission. """ - mandateId = await _validateInstanceAdmin(instanceId, context) + mandateId = _validateInstanceAdmin(instanceId, context) rootInterface = getRootInterface() diff --git a/modules/routes/routeAdmin.py b/modules/routes/routeAdmin.py index 878dbd66..ed5bf42c 100644 --- a/modules/routes/routeAdmin.py +++ b/modules/routes/routeAdmin.py @@ -33,7 +33,7 @@ router.mount( @router.get("/") @limiter.limit("30/minute") -async def root(request: Request) -> Dict[str, str]: +def root(request: Request) -> Dict[str, str]: """API status endpoint""" # Validate required configuration values allowedOrigins = APP_CONFIG.get("APP_ALLOWED_ORIGINS") @@ -51,7 +51,7 @@ async def root(request: Request) -> Dict[str, str]: @router.get("/api/environment") @limiter.limit("30/minute") -async def get_environment( +def get_environment( request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser) ) -> Dict[str, str]: """Get environment configuration for frontend""" @@ -82,13 +82,13 @@ async def get_environment( @router.options("/{fullPath:path}") @limiter.limit("60/minute") -async def options_route(request: Request, fullPath: str) -> Response: +def options_route(request: Request, fullPath: str) -> Response: return Response(status_code=200) @router.get("/favicon.ico") @limiter.limit("30/minute") -async def favicon(request: Request) -> FileResponse: +def favicon(request: Request) -> FileResponse: favicon_path = staticFolder / "favicon.ico" if not favicon_path.exists(): raise HTTPException(status_code=404, detail="Favicon not found") diff --git a/modules/routes/routeAdminAutomationEvents.py b/modules/routes/routeAdminAutomationEvents.py index e8bb9291..7765d621 100644 --- a/modules/routes/routeAdminAutomationEvents.py +++ b/modules/routes/routeAdminAutomationEvents.py @@ -33,7 +33,7 @@ router = APIRouter( @router.get("") @limiter.limit("30/minute") -async def get_all_automation_events( +def get_all_automation_events( request: Request, currentUser: User = Depends(requireSysAdmin) ) -> List[Dict[str, Any]]: @@ -107,7 +107,7 @@ async def sync_all_automation_events( @router.post("/{eventId}/remove") @limiter.limit("10/minute") -async def remove_event( +def remove_event( request: Request, eventId: str = Path(..., description="Event ID to remove"), currentUser: User = Depends(requireSysAdmin) diff --git a/modules/routes/routeAdminFeatures.py b/modules/routes/routeAdminFeatures.py index 87582b9e..3adc8025 100644 --- a/modules/routes/routeAdminFeatures.py +++ b/modules/routes/routeAdminFeatures.py @@ -67,7 +67,7 @@ class SyncRolesResult(BaseModel): @router.get("/", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def list_features( +def list_features( request: Request, context: RequestContext = Depends(getRequestContext) ) -> List[Dict[str, Any]]: @@ -105,7 +105,7 @@ class FeaturesMyResponse(BaseModel): @router.get("/my", response_model=FeaturesMyResponse) @limiter.limit("60/minute") -async def get_my_feature_instances( +def get_my_feature_instances( request: Request, context: RequestContext = Depends(getRequestContext) ) -> FeaturesMyResponse: @@ -332,7 +332,7 @@ def _mergeAccessLevel(current: str, new: str) -> str: @router.post("/", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def create_feature( +def create_feature( request: Request, code: str = Query(..., description="Unique feature code"), label: Dict[str, str] = None, @@ -387,7 +387,7 @@ async def create_feature( @router.get("/instances", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def list_feature_instances( +def list_feature_instances( request: Request, featureCode: Optional[str] = Query(None, description="Filter by feature code"), context: RequestContext = Depends(getRequestContext) @@ -429,7 +429,7 @@ async def list_feature_instances( @router.get("/instances/{instanceId}", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def get_feature_instance( +def get_feature_instance( request: Request, instanceId: str, context: RequestContext = Depends(getRequestContext) @@ -473,7 +473,7 @@ async def get_feature_instance( @router.post("/instances", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def create_feature_instance( +def create_feature_instance( request: Request, data: FeatureInstanceCreate, context: RequestContext = Depends(getRequestContext) @@ -540,7 +540,7 @@ async def create_feature_instance( @router.delete("/instances/{instanceId}", response_model=Dict[str, str]) @limiter.limit("10/minute") -async def delete_feature_instance( +def delete_feature_instance( request: Request, instanceId: str, context: RequestContext = Depends(getRequestContext) @@ -605,7 +605,7 @@ class FeatureInstanceUpdate(BaseModel): @router.put("/instances/{instanceId}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def updateFeatureInstance( +def updateFeatureInstance( request: Request, instanceId: str, data: FeatureInstanceUpdate, @@ -682,7 +682,7 @@ async def updateFeatureInstance( @router.post("/instances/{instanceId}/sync-roles", response_model=SyncRolesResult) @limiter.limit("10/minute") -async def sync_instance_roles( +def sync_instance_roles( request: Request, instanceId: str, addOnly: bool = Query(True, description="Only add missing roles, don't remove extras"), @@ -749,7 +749,7 @@ async def sync_instance_roles( @router.get("/templates/roles", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def list_template_roles( +def list_template_roles( request: Request, featureCode: Optional[str] = Query(None, description="Filter by feature code"), sysAdmin: User = Depends(requireSysAdmin) @@ -779,7 +779,7 @@ async def list_template_roles( @router.post("/templates/roles", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def create_template_role( +def create_template_role( request: Request, roleLabel: str = Query(..., description="Role label (e.g., 'admin', 'viewer')"), featureCode: str = Query(..., description="Feature code this role belongs to"), @@ -864,7 +864,7 @@ class FeatureInstanceUserUpdate(BaseModel): @router.get("/instances/{instanceId}/users", response_model=List[FeatureInstanceUserResponse]) @limiter.limit("60/minute") -async def list_feature_instance_users( +def list_feature_instance_users( request: Request, instanceId: str, context: RequestContext = Depends(getRequestContext) @@ -942,7 +942,7 @@ async def list_feature_instance_users( @router.post("/instances/{instanceId}/users", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def add_user_to_feature_instance( +def add_user_to_feature_instance( request: Request, instanceId: str, data: FeatureInstanceUserCreate, @@ -1043,7 +1043,7 @@ async def add_user_to_feature_instance( @router.delete("/instances/{instanceId}/users/{userId}", response_model=Dict[str, str]) @limiter.limit("30/minute") -async def remove_user_from_feature_instance( +def remove_user_from_feature_instance( request: Request, instanceId: str, userId: str, @@ -1121,7 +1121,7 @@ async def remove_user_from_feature_instance( @router.put("/instances/{instanceId}/users/{userId}/roles", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def update_feature_instance_user_roles( +def update_feature_instance_user_roles( request: Request, instanceId: str, userId: str, @@ -1216,7 +1216,7 @@ async def update_feature_instance_user_roles( @router.get("/instances/{instanceId}/available-roles", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_feature_instance_available_roles( +def get_feature_instance_available_roles( request: Request, instanceId: str, context: RequestContext = Depends(getRequestContext) @@ -1280,7 +1280,7 @@ async def get_feature_instance_available_roles( @router.get("/{featureCode}", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def get_feature( +def get_feature( request: Request, featureCode: str, context: RequestContext = Depends(getRequestContext) diff --git a/modules/routes/routeAdminRbacExport.py b/modules/routes/routeAdminRbacExport.py index 28caf8c8..d22a6ba7 100644 --- a/modules/routes/routeAdminRbacExport.py +++ b/modules/routes/routeAdminRbacExport.py @@ -72,7 +72,7 @@ class RbacImportResult(BaseModel): @router.get("/export/global", response_model=RbacExportData) @limiter.limit("10/minute") -async def export_global_rbac( +def export_global_rbac( request: Request, sysAdmin: User = Depends(requireSysAdmin) ) -> RbacExportData: @@ -281,7 +281,7 @@ async def import_global_rbac( @router.get("/export/mandate", response_model=RbacExportData) @limiter.limit("10/minute") -async def export_mandate_rbac( +def export_mandate_rbac( request: Request, includeFeatureInstances: bool = True, context: RequestContext = Depends(getRequestContext) diff --git a/modules/routes/routeAdminRbacRoles.py b/modules/routes/routeAdminRbacRoles.py index 75e00cd5..97830991 100644 --- a/modules/routes/routeAdminRbacRoles.py +++ b/modules/routes/routeAdminRbacRoles.py @@ -68,7 +68,7 @@ router = APIRouter( @router.get("/", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def list_roles( +def list_roles( request: Request, currentUser: User = Depends(requireSysAdmin) ) -> List[Dict[str, Any]]: @@ -113,7 +113,7 @@ async def list_roles( @router.get("/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_role_options( +def get_role_options( request: Request, currentUser: User = Depends(requireSysAdmin) ) -> List[Dict[str, Any]]: @@ -154,7 +154,7 @@ async def get_role_options( @router.post("/", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def create_role( +def create_role( request: Request, role: Role = Body(...), currentUser: User = Depends(requireSysAdmin) @@ -198,7 +198,7 @@ async def create_role( @router.get("/{roleId}", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def get_role( +def get_role( request: Request, roleId: str = Path(..., description="Role ID"), currentUser: User = Depends(requireSysAdmin) @@ -242,7 +242,7 @@ async def get_role( @router.put("/{roleId}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def update_role( +def update_role( request: Request, roleId: str = Path(..., description="Role ID"), role: Role = Body(...), @@ -290,7 +290,7 @@ async def update_role( @router.delete("/{roleId}", response_model=Dict[str, str]) @limiter.limit("30/minute") -async def delete_role( +def delete_role( request: Request, roleId: str = Path(..., description="Role ID"), currentUser: User = Depends(requireSysAdmin) @@ -334,7 +334,7 @@ async def delete_role( @router.get("/users", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def list_users_with_roles( +def list_users_with_roles( request: Request, roleLabel: Optional[str] = Query(None, description="Filter by role label"), mandateId: Optional[str] = Query(None, description="Filter by mandate ID (via UserMandate)"), @@ -396,7 +396,7 @@ async def list_users_with_roles( @router.get("/users/{userId}", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def get_user_roles( +def get_user_roles( request: Request, userId: str = Path(..., description="User ID"), currentUser: User = Depends(requireSysAdmin) @@ -446,7 +446,7 @@ async def get_user_roles( @router.put("/users/{userId}/roles", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def update_user_roles( +def update_user_roles( request: Request, userId: str = Path(..., description="User ID"), newRoleLabels: List[str] = Body(..., description="List of role labels to assign"), @@ -540,7 +540,7 @@ async def update_user_roles( @router.post("/users/{userId}/roles/{roleLabel}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def add_user_role( +def add_user_role( request: Request, userId: str = Path(..., description="User ID"), roleLabel: str = Path(..., description="Role label to add"), @@ -619,7 +619,7 @@ async def add_user_role( @router.delete("/users/{userId}/roles/{roleLabel}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def remove_user_role( +def remove_user_role( request: Request, userId: str = Path(..., description="User ID"), roleLabel: str = Path(..., description="Role label to remove"), @@ -693,7 +693,7 @@ async def remove_user_role( @router.get("/roles/{roleLabel}/users", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_users_with_role( +def get_users_with_role( request: Request, roleLabel: str = Path(..., description="Role label"), mandateId: Optional[str] = Query(None, description="Filter by mandate ID (via UserMandate)"), diff --git a/modules/routes/routeAdminRbacRules.py b/modules/routes/routeAdminRbacRules.py index 82cc13d7..1feb64a2 100644 --- a/modules/routes/routeAdminRbacRules.py +++ b/modules/routes/routeAdminRbacRules.py @@ -35,7 +35,7 @@ router = APIRouter( @router.get("/permissions", response_model=UserPermissions) @limiter.limit("300/minute") # Raised from 60 - sidebar checks many pages individually -async def get_permissions( +def get_permissions( request: Request, context: str = Query(..., description="Context type: DATA, UI, or RESOURCE"), item: Optional[str] = Query(None, description="Item identifier (table name, UI path, or resource path)"), @@ -101,7 +101,7 @@ async def get_permissions( @router.get("/permissions/all", response_model=Dict[str, Any]) @limiter.limit("120/minute") # Raised from 30 - optimized endpoint for bulk permission fetch -async def get_all_permissions( +def get_all_permissions( request: Request, context: Optional[str] = Query(None, description="Context type: UI or RESOURCE (if not provided, returns both)"), reqContext: RequestContext = Depends(getRequestContext) @@ -293,7 +293,7 @@ async def get_all_permissions( @router.get("/rules", response_model=PaginatedResponse) @limiter.limit("30/minute") -async def get_access_rules( +def get_access_rules( request: Request, roleLabel: Optional[str] = Query(None, description="Filter by role label"), context: Optional[str] = Query(None, description="Filter by context (DATA, UI, RESOURCE)"), @@ -382,7 +382,7 @@ async def get_access_rules( @router.get("/rules/by-role/{roleId}", response_model=PaginatedResponse) @limiter.limit("30/minute") -async def get_access_rules_by_role( +def get_access_rules_by_role( request: Request, roleId: str = Path(..., description="Role ID to get rules for"), currentUser: User = Depends(requireSysAdmin) @@ -420,7 +420,7 @@ async def get_access_rules_by_role( @router.get("/rules/{ruleId}", response_model=dict) @limiter.limit("30/minute") -async def get_access_rule( +def get_access_rule( request: Request, ruleId: str = Path(..., description="Access rule ID"), currentUser: User = Depends(requireSysAdmin) @@ -462,7 +462,7 @@ async def get_access_rule( @router.post("/rules", response_model=dict) @limiter.limit("30/minute") -async def create_access_rule( +def create_access_rule( request: Request, accessRuleData: dict = Body(..., description="Access rule data"), currentUser: User = Depends(requireSysAdmin) @@ -528,7 +528,7 @@ async def create_access_rule( @router.put("/rules/{ruleId}", response_model=dict) @limiter.limit("30/minute") -async def update_access_rule( +def update_access_rule( request: Request, ruleId: str = Path(..., description="Access rule ID"), accessRuleData: dict = Body(..., description="Updated access rule data"), @@ -611,7 +611,7 @@ async def update_access_rule( @router.delete("/rules/{ruleId}") @limiter.limit("30/minute") -async def delete_access_rule( +def delete_access_rule( request: Request, ruleId: str = Path(..., description="Access rule ID"), currentUser: User = Depends(requireSysAdmin) @@ -669,7 +669,7 @@ async def delete_access_rule( @router.get("/roles", response_model=PaginatedResponse) @limiter.limit("60/minute") -async def list_roles( +def list_roles( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), includeTemplates: bool = Query(False, description="Include feature template roles"), @@ -838,7 +838,7 @@ async def list_roles( @router.get("/roles/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_role_options( +def get_role_options( request: Request, currentUser: User = Depends(requireSysAdmin) ) -> List[Dict[str, Any]]: @@ -879,7 +879,7 @@ async def get_role_options( @router.post("/roles", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def create_role( +def create_role( request: Request, role: Role = Body(...), currentUser: User = Depends(requireSysAdmin) @@ -928,7 +928,7 @@ async def create_role( @router.get("/roles/{roleId}", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def get_role( +def get_role( request: Request, roleId: str = Path(..., description="Role ID"), currentUser: User = Depends(requireSysAdmin) @@ -975,7 +975,7 @@ async def get_role( @router.put("/roles/{roleId}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def update_role( +def update_role( request: Request, roleId: str = Path(..., description="Role ID"), role: Role = Body(...), @@ -1028,7 +1028,7 @@ async def update_role( @router.delete("/roles/{roleId}", response_model=Dict[str, str]) @limiter.limit("30/minute") -async def delete_role( +def delete_role( request: Request, roleId: str = Path(..., description="Role ID"), currentUser: User = Depends(requireSysAdmin) @@ -1078,7 +1078,7 @@ async def delete_role( @router.get("/catalog/objects", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def getCatalogObjects( +def getCatalogObjects( request: Request, context: Optional[str] = Query(None, description="Filter by context (DATA, UI, RESOURCE)"), featureCode: Optional[str] = Query(None, description="Filter by feature code"), @@ -1170,7 +1170,7 @@ async def getCatalogObjects( @router.get("/catalog/stats", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def getCatalogStats( +def getCatalogStats( request: Request, currentUser: User = Depends(requireSysAdmin) ) -> Dict[str, Any]: @@ -1200,7 +1200,7 @@ async def getCatalogStats( @router.post("/cleanup/duplicate-rules", response_model=dict) @limiter.limit("5/minute") -async def cleanup_duplicate_access_rules( +def cleanup_duplicate_access_rules( request: Request, dryRun: bool = Query(True, description="If true, only report duplicates without deleting"), currentUser: User = Depends(requireSysAdmin) diff --git a/modules/routes/routeAdminUserAccessOverview.py b/modules/routes/routeAdminUserAccessOverview.py index 372e2193..b330d57e 100644 --- a/modules/routes/routeAdminUserAccessOverview.py +++ b/modules/routes/routeAdminUserAccessOverview.py @@ -69,7 +69,7 @@ def _getRoleScopePriority(scope: str) -> int: @router.get("/users", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def listUsersForOverview( +def listUsersForOverview( request: Request, currentUser: User = Depends(requireSysAdmin) ) -> List[Dict[str, Any]]: @@ -112,7 +112,7 @@ async def listUsersForOverview( @router.get("/{userId}", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def getUserAccessOverview( +def getUserAccessOverview( request: Request, userId: str = Path(..., description="User ID to get access overview for"), mandateId: Optional[str] = Query(None, description="Filter by mandate ID"), @@ -410,7 +410,7 @@ async def getUserAccessOverview( @router.get("/{userId}/effective-permissions", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def getEffectivePermissions( +def getEffectivePermissions( request: Request, userId: str = Path(..., description="User ID"), mandateId: str = Query(..., description="Mandate ID context"), diff --git a/modules/routes/routeAttributes.py b/modules/routes/routeAttributes.py index 10f93ce6..e877e512 100644 --- a/modules/routes/routeAttributes.py +++ b/modules/routes/routeAttributes.py @@ -22,7 +22,7 @@ router = APIRouter( @router.get("/{entityType}", response_model=AttributeResponse) @limiter.limit("30/minute") -async def get_entity_attributes( +def get_entity_attributes( request: Request, entityType: str = Path(..., description="Type of entity (e.g. prompt)") ) -> AttributeResponse: @@ -76,7 +76,7 @@ async def get_entity_attributes( @router.options("/{entityType}") @limiter.limit("60/minute") -async def options_entity_attributes( +def options_entity_attributes( request: Request, entityType: str = Path(..., description="Type of entity (e.g. prompt)") ) -> Response: diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index bd47c791..26133704 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -164,7 +164,7 @@ router = APIRouter( @router.get("/balance", response_model=List[BillingBalanceResponse]) @limiter.limit("60/minute") -async def getBalance( +def getBalance( request: Request, ctx: RequestContext = Depends(getRequestContext) ): @@ -189,7 +189,7 @@ async def getBalance( @router.get("/balance/{targetMandateId}", response_model=BillingBalanceResponse) @limiter.limit("60/minute") -async def getBalanceForMandate( +def getBalanceForMandate( request: Request, targetMandateId: str = Path(..., description="Mandate ID"), ctx: RequestContext = Depends(getRequestContext) @@ -230,7 +230,7 @@ async def getBalanceForMandate( @router.get("/transactions", response_model=List[TransactionResponse]) @limiter.limit("30/minute") -async def getTransactions( +def getTransactions( request: Request, limit: int = Query(default=50, ge=1, le=500), offset: int = Query(default=0, ge=0), @@ -276,7 +276,7 @@ async def getTransactions( @router.get("/statistics/{period}", response_model=UsageReportResponse) @limiter.limit("30/minute") -async def getStatistics( +def getStatistics( request: Request, period: str = Path(..., description="Period: 'day', 'month', or 'year'"), year: int = Query(..., description="Year"), @@ -361,7 +361,7 @@ async def getStatistics( @router.get("/providers", response_model=List[str]) @limiter.limit("60/minute") -async def getAllowedProviders( +def getAllowedProviders( request: Request, ctx: RequestContext = Depends(getRequestContext) ): @@ -388,7 +388,7 @@ async def getAllowedProviders( @router.get("/admin/settings/{targetMandateId}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def getSettingsAdmin( +def getSettingsAdmin( request: Request, targetMandateId: str = Path(..., description="Mandate ID"), ctx: RequestContext = Depends(getRequestContext), @@ -415,7 +415,7 @@ async def getSettingsAdmin( @router.post("/admin/settings/{targetMandateId}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def createOrUpdateSettings( +def createOrUpdateSettings( request: Request, targetMandateId: str = Path(..., description="Mandate ID"), settingsUpdate: BillingSettingsUpdate = Body(...), @@ -462,7 +462,7 @@ async def createOrUpdateSettings( @router.post("/admin/credit/{targetMandateId}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def addCredit( +def addCredit( request: Request, targetMandateId: str = Path(..., description="Mandate ID"), creditRequest: CreditAddRequest = Body(...), @@ -526,7 +526,7 @@ async def addCredit( @router.get("/admin/accounts/{targetMandateId}", response_model=List[AccountSummary]) @limiter.limit("30/minute") -async def getAccounts( +def getAccounts( request: Request, targetMandateId: str = Path(..., description="Mandate ID"), ctx: RequestContext = Depends(getRequestContext), @@ -572,7 +572,7 @@ class MandateUserSummary(BaseModel): @router.get("/admin/users/{targetMandateId}", response_model=List[MandateUserSummary]) @limiter.limit("30/minute") -async def getUsersForMandate( +def getUsersForMandate( request: Request, targetMandateId: str = Path(..., description="Mandate ID"), ctx: RequestContext = Depends(getRequestContext), @@ -627,7 +627,7 @@ async def getUsersForMandate( @router.get("/admin/transactions/{targetMandateId}", response_model=List[TransactionResponse]) @limiter.limit("30/minute") -async def getTransactionsAdmin( +def getTransactionsAdmin( request: Request, targetMandateId: str = Path(..., description="Mandate ID"), limit: int = Query(default=100, ge=1, le=1000), @@ -669,7 +669,7 @@ async def getTransactionsAdmin( @router.get("/view/mandates/balances", response_model=List[MandateBalanceResponse]) @limiter.limit("30/minute") -async def getMandateViewBalances( +def getMandateViewBalances( request: Request, ctx: RequestContext = Depends(getRequestContext), _admin = Depends(requireSysAdmin) @@ -691,7 +691,7 @@ async def getMandateViewBalances( @router.get("/view/mandates/transactions", response_model=List[TransactionResponse]) @limiter.limit("30/minute") -async def getMandateViewTransactions( +def getMandateViewTransactions( request: Request, limit: int = Query(default=100, ge=1, le=1000), ctx: RequestContext = Depends(getRequestContext), @@ -734,7 +734,7 @@ async def getMandateViewTransactions( @router.get("/view/users/balances", response_model=List[UserBalanceResponse]) @limiter.limit("30/minute") -async def getUserViewBalances( +def getUserViewBalances( request: Request, ctx: RequestContext = Depends(getRequestContext) ): @@ -793,7 +793,7 @@ class ViewStatisticsResponse(BaseModel): @router.get("/view/statistics") @limiter.limit("30/minute") -async def getUserViewStatistics( +def getUserViewStatistics( request: Request, period: str = Query(default="month", description="Period: 'day' or 'month'"), year: int = Query(default=None, description="Year"), @@ -962,7 +962,7 @@ async def getUserViewStatistics( @router.get("/view/users/transactions", response_model=PaginatedResponse[UserTransactionResponse]) @limiter.limit("30/minute") -async def getUserViewTransactions( +def getUserViewTransactions( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), ctx: RequestContext = Depends(getRequestContext) diff --git a/modules/routes/routeDataConnections.py b/modules/routes/routeDataConnections.py index 95bbd014..099b04c2 100644 --- a/modules/routes/routeDataConnections.py +++ b/modules/routes/routeDataConnections.py @@ -84,7 +84,7 @@ router = APIRouter( @router.get("/statuses/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_connection_status_options( +def get_connection_status_options( request: Request, currentUser: User = Depends(getCurrentUser) ) -> List[Dict[str, Any]]: @@ -100,7 +100,7 @@ async def get_connection_status_options( @router.get("/authorities/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_auth_authority_options( +def get_auth_authority_options( request: Request, currentUser: User = Depends(getCurrentUser) ) -> List[Dict[str, Any]]: @@ -288,7 +288,7 @@ async def get_connections( @router.post("/", response_model=UserConnection) @limiter.limit("10/minute") -async def create_connection( +def create_connection( request: Request, connection_data: Dict[str, Any] = Body(...), currentUser: User = Depends(getCurrentUser) @@ -344,7 +344,7 @@ async def create_connection( @router.put("/{connectionId}", response_model=UserConnection) @limiter.limit("10/minute") -async def update_connection( +def update_connection( request: Request, connectionId: str = Path(..., description="The ID of the connection to update"), connection_data: Dict[str, Any] = Body(...), @@ -416,7 +416,7 @@ async def update_connection( @router.post("/{connectionId}/connect") @limiter.limit("10/minute") -async def connect_service( +def connect_service( request: Request, connectionId: str = Path(..., description="The ID of the connection to connect"), currentUser: User = Depends(getCurrentUser) @@ -482,7 +482,7 @@ async def connect_service( @router.post("/{connectionId}/disconnect") @limiter.limit("10/minute") -async def disconnect_service( +def disconnect_service( request: Request, connectionId: str = Path(..., description="The ID of the connection to disconnect"), currentUser: User = Depends(getCurrentUser) @@ -532,7 +532,7 @@ async def disconnect_service( @router.delete("/{connectionId}") @limiter.limit("10/minute") -async def delete_connection( +def delete_connection( request: Request, connectionId: str = Path(..., description="The ID of the connection to delete"), currentUser: User = Depends(getCurrentUser) diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py index 1a84b7e4..49d7e365 100644 --- a/modules/routes/routeDataFiles.py +++ b/modules/routes/routeDataFiles.py @@ -37,7 +37,7 @@ router = APIRouter( @router.get("/list", response_model=PaginatedResponse[FileItem]) @limiter.limit("30/minute") -async def get_files( +def get_files( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), currentUser: User = Depends(getCurrentUser) @@ -168,7 +168,7 @@ async def upload_file( @router.get("/{fileId}", response_model=FileItem) @limiter.limit("30/minute") -async def get_file( +def get_file( request: Request, fileId: str = Path(..., description="ID of the file"), currentUser: User = Depends(getCurrentUser) @@ -214,7 +214,7 @@ async def get_file( @router.put("/{fileId}", response_model=FileItem) @limiter.limit("10/minute") -async def update_file( +def update_file( request: Request, fileId: str = Path(..., description="ID of the file to update"), file_info: Dict[str, Any] = Body(...), @@ -262,7 +262,7 @@ async def update_file( @router.delete("/{fileId}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def delete_file( +def delete_file( request: Request, fileId: str = Path(..., description="ID of the file to delete"), currentUser: User = Depends(getCurrentUser) @@ -289,7 +289,7 @@ async def delete_file( @router.get("/stats", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def get_file_stats( +def get_file_stats( request: Request, currentUser: User = Depends(getCurrentUser) ) -> Dict[str, Any]: @@ -327,7 +327,7 @@ async def get_file_stats( @router.get("/{fileId}/download") @limiter.limit("30/minute") -async def download_file( +def download_file( request: Request, fileId: str = Path(..., description="ID of the file to download"), currentUser: User = Depends(getCurrentUser) @@ -375,7 +375,7 @@ async def download_file( @router.get("/{fileId}/preview", response_model=FilePreview) @limiter.limit("30/minute") -async def preview_file( +def preview_file( request: Request, fileId: str = Path(..., description="ID of the file to preview"), currentUser: User = Depends(getCurrentUser) diff --git a/modules/routes/routeDataMandates.py b/modules/routes/routeDataMandates.py index 38877a9f..8d2c4a2b 100644 --- a/modules/routes/routeDataMandates.py +++ b/modules/routes/routeDataMandates.py @@ -76,7 +76,7 @@ router = APIRouter( @router.get("/", response_model=PaginatedResponse[Mandate]) @limiter.limit("30/minute") -async def get_mandates( +def get_mandates( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), currentUser: User = Depends(requireSysAdmin) @@ -140,7 +140,7 @@ async def get_mandates( @router.get("/{mandateId}", response_model=Mandate) @limiter.limit("30/minute") -async def get_mandate( +def get_mandate( request: Request, mandateId: str = Path(..., description="ID of the mandate"), currentUser: User = Depends(requireSysAdmin) @@ -171,7 +171,7 @@ async def get_mandate( @router.post("/", response_model=Mandate) @limiter.limit("10/minute") -async def create_mandate( +def create_mandate( request: Request, mandateData: dict = Body(..., description="Mandate data with at least 'name' field"), currentUser: User = Depends(requireSysAdmin) @@ -224,7 +224,7 @@ async def create_mandate( @router.put("/{mandateId}", response_model=Mandate) @limiter.limit("10/minute") -async def update_mandate( +def update_mandate( request: Request, mandateId: str = Path(..., description="ID of the mandate to update"), mandateData: dict = Body(..., description="Mandate update data"), @@ -270,7 +270,7 @@ async def update_mandate( @router.delete("/{mandateId}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def delete_mandate( +def delete_mandate( request: Request, mandateId: str = Path(..., description="ID of the mandate to delete"), currentUser: User = Depends(requireSysAdmin) @@ -324,7 +324,7 @@ async def delete_mandate( @router.get("/{targetMandateId}/users") @limiter.limit("60/minute") -async def list_mandate_users( +def list_mandate_users( request: Request, targetMandateId: str = Path(..., description="ID of the mandate"), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), @@ -493,7 +493,7 @@ async def list_mandate_users( @router.post("/{targetMandateId}/users", response_model=UserMandateResponse) @limiter.limit("30/minute") -async def add_user_to_mandate( +def add_user_to_mandate( request: Request, targetMandateId: str = Path(..., description="ID of the mandate"), data: UserMandateCreate = Body(...), @@ -603,7 +603,7 @@ async def add_user_to_mandate( @router.delete("/{targetMandateId}/users/{targetUserId}", response_model=Dict[str, str]) @limiter.limit("30/minute") -async def remove_user_from_mandate( +def remove_user_from_mandate( request: Request, targetMandateId: str = Path(..., description="ID of the mandate"), targetUserId: str = Path(..., description="ID of the user to remove"), @@ -681,7 +681,7 @@ async def remove_user_from_mandate( @router.put("/{targetMandateId}/users/{targetUserId}/roles", response_model=UserMandateResponse) @limiter.limit("30/minute") -async def update_user_roles_in_mandate( +def update_user_roles_in_mandate( request: Request, targetMandateId: str = Path(..., description="ID of the mandate"), targetUserId: str = Path(..., description="ID of the user"), diff --git a/modules/routes/routeDataPrompts.py b/modules/routes/routeDataPrompts.py index 48902e66..4aad221d 100644 --- a/modules/routes/routeDataPrompts.py +++ b/modules/routes/routeDataPrompts.py @@ -27,7 +27,7 @@ router = APIRouter( @router.get("", response_model=PaginatedResponse[Prompt]) @limiter.limit("30/minute") -async def get_prompts( +def get_prompts( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), currentUser: User = Depends(getCurrentUser) @@ -83,7 +83,7 @@ async def get_prompts( @router.post("", response_model=Prompt) @limiter.limit("10/minute") -async def create_prompt( +def create_prompt( request: Request, prompt: Prompt, currentUser: User = Depends(getCurrentUser) @@ -98,7 +98,7 @@ async def create_prompt( @router.get("/{promptId}", response_model=Prompt) @limiter.limit("30/minute") -async def get_prompt( +def get_prompt( request: Request, promptId: str = Path(..., description="ID of the prompt"), currentUser: User = Depends(getCurrentUser) @@ -118,7 +118,7 @@ async def get_prompt( @router.put("/{promptId}", response_model=Prompt) @limiter.limit("10/minute") -async def update_prompt( +def update_prompt( request: Request, promptId: str = Path(..., description="ID of the prompt to update"), promptData: Prompt = Body(...), @@ -154,7 +154,7 @@ async def update_prompt( @router.delete("/{promptId}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def delete_prompt( +def delete_prompt( request: Request, promptId: str = Path(..., description="ID of the prompt to delete"), currentUser: User = Depends(getCurrentUser) diff --git a/modules/routes/routeDataUsers.py b/modules/routes/routeDataUsers.py index 5e78d12a..b269e57e 100644 --- a/modules/routes/routeDataUsers.py +++ b/modules/routes/routeDataUsers.py @@ -153,7 +153,7 @@ router = APIRouter( @router.get("/options", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def get_user_options( +def get_user_options( request: Request, context: RequestContext = Depends(getRequestContext) ) -> List[Dict[str, Any]]: @@ -190,7 +190,7 @@ async def get_user_options( @router.get("/", response_model=PaginatedResponse[User]) @limiter.limit("30/minute") -async def get_users( +def get_users( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), context: RequestContext = Depends(getRequestContext) @@ -304,7 +304,7 @@ async def get_users( @router.get("/{userId}", response_model=User) @limiter.limit("30/minute") -async def get_user( +def get_user( request: Request, userId: str = Path(..., description="ID of the user"), context: RequestContext = Depends(getRequestContext) @@ -356,7 +356,7 @@ class CreateUserRequest(BaseModel): @router.post("", response_model=User) @limiter.limit("10/minute") -async def create_user( +def create_user( request: Request, userData: CreateUserRequest = Body(...), context: RequestContext = Depends(getRequestContext) @@ -396,7 +396,7 @@ async def create_user( @router.put("/{userId}", response_model=User) @limiter.limit("10/minute") -async def update_user( +def update_user( request: Request, userId: str = Path(..., description="ID of the user to update"), userData: User = Body(...), @@ -438,7 +438,7 @@ async def update_user( @router.post("/{userId}/reset-password") @limiter.limit("5/minute") -async def reset_user_password( +def reset_user_password( request: Request, userId: str = Path(..., description="ID of the user to reset password for"), newPassword: str = Body(..., embed=True), @@ -535,7 +535,7 @@ async def reset_user_password( @router.post("/change-password") @limiter.limit("5/minute") -async def change_password( +def change_password( request: Request, currentPassword: str = Body(..., embed=True), newPassword: str = Body(..., embed=True), @@ -614,7 +614,7 @@ async def change_password( @router.post("/{userId}/send-password-link") @limiter.limit("10/minute") -async def send_password_link( +def send_password_link( request: Request, userId: str = Path(..., description="ID of the user to send password setup link"), frontendUrl: str = Body(..., embed=True), @@ -749,7 +749,7 @@ Falls Sie diese Anforderung nicht erwartet haben, kontaktieren Sie bitte Ihren A @router.delete("/{userId}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def delete_user( +def delete_user( request: Request, userId: str = Path(..., description="ID of the user to delete"), context: RequestContext = Depends(getRequestContext) diff --git a/modules/routes/routeDataWorkflows.py b/modules/routes/routeDataWorkflows.py index 80ca5986..88b41009 100644 --- a/modules/routes/routeDataWorkflows.py +++ b/modules/routes/routeDataWorkflows.py @@ -50,7 +50,7 @@ def getServiceChat(currentUser: User): # Consolidated endpoint for getting all workflows @router.get("/", response_model=PaginatedResponse[ChatWorkflow]) @limiter.limit("120/minute") -async def get_workflows( +def get_workflows( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), currentUser: User = Depends(getCurrentUser) @@ -123,7 +123,7 @@ async def get_workflows( @router.get("/{workflowId}", response_model=ChatWorkflow) @limiter.limit("120/minute") -async def get_workflow( +def get_workflow( request: Request, workflowId: str = Path(..., description="ID of the workflow"), currentUser: User = Depends(getCurrentUser) @@ -152,7 +152,7 @@ async def get_workflow( @router.put("/{workflowId}", response_model=ChatWorkflow) @limiter.limit("120/minute") -async def update_workflow( +def update_workflow( request: Request, workflowId: str = Path(..., description="ID of the workflow to update"), workflowData: Dict[str, Any] = Body(...), @@ -200,7 +200,7 @@ async def update_workflow( # API Endpoint for workflow status @router.get("/{workflowId}/status", response_model=ChatWorkflow) @limiter.limit("120/minute") -async def get_workflow_status( +def get_workflow_status( request: Request, workflowId: str = Path(..., description="ID of the workflow"), currentUser: User = Depends(getCurrentUser) @@ -274,7 +274,7 @@ async def stop_workflow( # API Endpoint for workflow logs with selective data transfer @router.get("/{workflowId}/logs", response_model=PaginatedResponse[ChatLog]) @limiter.limit("120/minute") -async def get_workflow_logs( +def get_workflow_logs( request: Request, workflowId: str = Path(..., description="ID of the workflow"), logId: Optional[str] = Query(None, description="Optional log ID to get only newer logs (legacy selective data transfer)"), @@ -365,7 +365,7 @@ async def get_workflow_logs( # API Endpoint for workflow messages with selective data transfer @router.get("/{workflowId}/messages", response_model=PaginatedResponse[ChatMessage]) @limiter.limit("120/minute") -async def get_workflow_messages( +def get_workflow_messages( request: Request, workflowId: str = Path(..., description="ID of the workflow"), messageId: Optional[str] = Query(None, description="Optional message ID to get only newer messages (legacy selective data transfer)"), @@ -457,7 +457,7 @@ async def get_workflow_messages( # State 11: Workflow Reset/Deletion endpoint @router.delete("/{workflowId}", response_model=Dict[str, Any]) @limiter.limit("120/minute") -async def delete_workflow( +def delete_workflow( request: Request, workflowId: str = Path(..., description="ID of the workflow to delete"), currentUser: User = Depends(getCurrentUser) @@ -516,7 +516,7 @@ async def delete_workflow( @router.delete("/{workflowId}/messages/{messageId}", response_model=Dict[str, Any]) @limiter.limit("120/minute") -async def delete_workflow_message( +def delete_workflow_message( request: Request, workflowId: str = Path(..., description="ID of the workflow"), messageId: str = Path(..., description="ID of the message to delete"), @@ -566,7 +566,7 @@ async def delete_workflow_message( @router.delete("/{workflowId}/messages/{messageId}/files/{fileId}", response_model=Dict[str, Any]) @limiter.limit("120/minute") -async def delete_file_from_message( +def delete_file_from_message( request: Request, workflowId: str = Path(..., description="ID of the workflow"), messageId: str = Path(..., description="ID of the message"), @@ -615,7 +615,7 @@ async def delete_file_from_message( @router.get("/actions", response_model=Dict[str, Any]) @limiter.limit("120/minute") -async def get_all_actions( +def get_all_actions( request: Request, currentUser: User = Depends(getCurrentUser) ) -> Dict[str, Any]: @@ -685,7 +685,7 @@ async def get_all_actions( @router.get("/actions/{method}", response_model=Dict[str, Any]) @limiter.limit("120/minute") -async def get_method_actions( +def get_method_actions( request: Request, method: str = Path(..., description="Method name (e.g., 'outlook', 'sharepoint')"), currentUser: User = Depends(getCurrentUser) @@ -768,7 +768,7 @@ async def get_method_actions( @router.get("/actions/{method}/{action}", response_model=Dict[str, Any]) @limiter.limit("120/minute") -async def get_action_schema( +def get_action_schema( request: Request, method: str = Path(..., description="Method name (e.g., 'outlook', 'sharepoint')"), action: str = Path(..., description="Action name (e.g., 'readEmails', 'uploadDocument')"), diff --git a/modules/routes/routeGdpr.py b/modules/routes/routeGdpr.py index af0c7199..abc39b1f 100644 --- a/modules/routes/routeGdpr.py +++ b/modules/routes/routeGdpr.py @@ -74,7 +74,7 @@ class DeletionResult(BaseModel): @router.get("/data-export", response_model=DataExportResponse) @limiter.limit("5/minute") -async def export_user_data( +def export_user_data( request: Request, currentUser: User = Depends(getCurrentUser) ) -> DataExportResponse: @@ -215,7 +215,7 @@ async def export_user_data( @router.get("/data-portability") @limiter.limit("5/minute") -async def export_portable_data( +def export_portable_data( request: Request, currentUser: User = Depends(getCurrentUser) ) -> JSONResponse: @@ -296,7 +296,7 @@ async def export_portable_data( @router.delete("/", response_model=DeletionResult) @limiter.limit("1/hour") -async def delete_account( +def delete_account( request: Request, confirmDeletion: bool = False, currentUser: User = Depends(getCurrentUser) @@ -391,7 +391,7 @@ async def delete_account( @router.get("/consent-info", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def get_consent_info( +def get_consent_info( request: Request, currentUser: User = Depends(getCurrentUser) ) -> Dict[str, Any]: diff --git a/modules/routes/routeInvitations.py b/modules/routes/routeInvitations.py index 6a53fb38..095b84fb 100644 --- a/modules/routes/routeInvitations.py +++ b/modules/routes/routeInvitations.py @@ -94,7 +94,7 @@ class InvitationValidation(BaseModel): @router.post("/", response_model=InvitationResponse) @limiter.limit("30/minute") -async def create_invitation( +def create_invitation( request: Request, data: InvitationCreate, context: RequestContext = Depends(getRequestContext) @@ -300,7 +300,7 @@ async def create_invitation( @router.get("/", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def list_invitations( +def list_invitations( request: Request, includeUsed: bool = Query(False, description="Include already used invitations"), includeExpired: bool = Query(False, description="Include expired invitations"), @@ -379,7 +379,7 @@ async def list_invitations( @router.delete("/{invitationId}", response_model=Dict[str, str]) @limiter.limit("30/minute") -async def revoke_invitation( +def revoke_invitation( request: Request, invitationId: str, context: RequestContext = Depends(getRequestContext) @@ -458,7 +458,7 @@ async def revoke_invitation( @router.get("/validate/{token}", response_model=InvitationValidation) @limiter.limit("30/minute") -async def validate_invitation( +def validate_invitation( request: Request, token: str ) -> InvitationValidation: @@ -562,7 +562,7 @@ async def validate_invitation( @router.post("/accept/{token}", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def accept_invitation( +def accept_invitation( request: Request, token: str, currentUser: User = Depends(getCurrentUser) diff --git a/modules/routes/routeMessaging.py b/modules/routes/routeMessaging.py index 419e9ae6..223181e0 100644 --- a/modules/routes/routeMessaging.py +++ b/modules/routes/routeMessaging.py @@ -38,7 +38,7 @@ router = APIRouter( @router.get("/subscriptions", response_model=PaginatedResponse[MessagingSubscription]) @limiter.limit("60/minute") -async def get_subscriptions( +def get_subscriptions( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), currentUser: User = Depends(getCurrentUser) @@ -79,7 +79,7 @@ async def get_subscriptions( @router.post("/subscriptions", response_model=MessagingSubscription) @limiter.limit("60/minute") -async def create_subscription( +def create_subscription( request: Request, subscription: MessagingSubscription, currentUser: User = Depends(getCurrentUser) @@ -95,7 +95,7 @@ async def create_subscription( @router.get("/subscriptions/{subscriptionId}", response_model=MessagingSubscription) @limiter.limit("60/minute") -async def get_subscription( +def get_subscription( request: Request, subscriptionId: str = Path(..., description="ID of the subscription"), currentUser: User = Depends(getCurrentUser) @@ -115,7 +115,7 @@ async def get_subscription( @router.put("/subscriptions/{subscriptionId}", response_model=MessagingSubscription) @limiter.limit("60/minute") -async def update_subscription( +def update_subscription( request: Request, subscriptionId: str = Path(..., description="ID of the subscription to update"), subscriptionData: MessagingSubscription = Body(...), @@ -145,7 +145,7 @@ async def update_subscription( @router.delete("/subscriptions/{subscriptionId}", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def delete_subscription( +def delete_subscription( request: Request, subscriptionId: str = Path(..., description="ID of the subscription to delete"), currentUser: User = Depends(getCurrentUser) @@ -174,7 +174,7 @@ async def delete_subscription( @router.get("/subscriptions/{subscriptionId}/registrations", response_model=PaginatedResponse[MessagingSubscriptionRegistration]) @limiter.limit("60/minute") -async def get_subscription_registrations( +def get_subscription_registrations( request: Request, subscriptionId: str = Path(..., description="ID of the subscription"), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), @@ -219,7 +219,7 @@ async def get_subscription_registrations( @router.post("/subscriptions/{subscriptionId}/subscribe", response_model=MessagingSubscriptionRegistration) @limiter.limit("60/minute") -async def subscribe_user( +def subscribe_user( request: Request, subscriptionId: str = Path(..., description="ID of the subscription"), channel: MessagingChannel = Body(..., embed=True), @@ -241,7 +241,7 @@ async def subscribe_user( @router.delete("/subscriptions/{subscriptionId}/unsubscribe", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def unsubscribe_user( +def unsubscribe_user( request: Request, subscriptionId: str = Path(..., description="ID of the subscription"), channel: MessagingChannel = Body(..., embed=True), @@ -267,7 +267,7 @@ async def unsubscribe_user( @router.get("/registrations", response_model=PaginatedResponse[MessagingSubscriptionRegistration]) @limiter.limit("60/minute") -async def get_my_registrations( +def get_my_registrations( request: Request, pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), currentUser: User = Depends(getCurrentUser) @@ -311,7 +311,7 @@ async def get_my_registrations( @router.put("/registrations/{registrationId}", response_model=MessagingSubscriptionRegistration) @limiter.limit("60/minute") -async def update_registration( +def update_registration( request: Request, registrationId: str = Path(..., description="ID of the registration to update"), registrationData: MessagingSubscriptionRegistration = Body(...), @@ -341,7 +341,7 @@ async def update_registration( @router.delete("/registrations/{registrationId}", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def delete_registration( +def delete_registration( request: Request, registrationId: str = Path(..., description="ID of the registration to delete"), currentUser: User = Depends(getCurrentUser) @@ -376,7 +376,7 @@ def _getTriggerKey(request: Request) -> str: @router.post("/trigger/{subscriptionId}", response_model=MessagingSubscriptionExecutionResult) @limiter.limit("60/minute", key_func=_getTriggerKey) -async def trigger_subscription( +def trigger_subscription( request: Request, subscriptionId: str = Path(..., description="ID of the subscription to trigger"), eventParameters: Dict[str, Any] = Body(...), @@ -439,7 +439,7 @@ def _hasTriggerPermission(context: RequestContext) -> bool: @router.get("/deliveries", response_model=PaginatedResponse[MessagingDelivery]) @limiter.limit("60/minute") -async def get_deliveries( +def get_deliveries( request: Request, subscriptionId: Optional[str] = Query(None, description="Filter by subscription ID"), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), @@ -485,7 +485,7 @@ async def get_deliveries( @router.get("/deliveries/{deliveryId}", response_model=MessagingDelivery) @limiter.limit("60/minute") -async def get_delivery( +def get_delivery( request: Request, deliveryId: str = Path(..., description="ID of the delivery"), currentUser: User = Depends(getCurrentUser) diff --git a/modules/routes/routeNotifications.py b/modules/routes/routeNotifications.py index 4fc09ac4..00e9bbcd 100644 --- a/modules/routes/routeNotifications.py +++ b/modules/routes/routeNotifications.py @@ -120,7 +120,7 @@ def createInvitationNotification( @router.get("", response_model=List[Dict[str, Any]]) @limiter.limit("60/minute") -async def getNotifications( +def getNotifications( request: Request, currentUser: User = Depends(getCurrentUser), status: Optional[str] = None, @@ -161,7 +161,7 @@ async def getNotifications( @router.get("/unread-count", response_model=UnreadCountResponse) @limiter.limit("120/minute") -async def getUnreadCount( +def getUnreadCount( request: Request, currentUser: User = Depends(getCurrentUser) ) -> UnreadCountResponse: @@ -190,7 +190,7 @@ async def getUnreadCount( @router.put("/{notificationId}/read", response_model=Dict[str, Any]) @limiter.limit("60/minute") -async def markAsRead( +def markAsRead( request: Request, notificationId: str, currentUser: User = Depends(getCurrentUser) @@ -241,7 +241,7 @@ async def markAsRead( @router.put("/mark-all-read", response_model=Dict[str, Any]) @limiter.limit("10/minute") -async def markAllAsRead( +def markAllAsRead( request: Request, currentUser: User = Depends(getCurrentUser) ) -> Dict[str, Any]: @@ -283,7 +283,7 @@ async def markAllAsRead( @router.post("/{notificationId}/action", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def executeAction( +def executeAction( request: Request, notificationId: str, actionRequest: NotificationActionRequest, @@ -332,7 +332,7 @@ async def executeAction( actionResult = None if notification.get("type") == NotificationType.INVITATION.value: - actionResult = await _handleInvitationAction( + actionResult = _handleInvitationAction( notification=notification, actionId=actionRequest.actionId, currentUser=currentUser, @@ -370,7 +370,7 @@ async def executeAction( ) -async def _handleInvitationAction( +def _handleInvitationAction( notification: Dict[str, Any], actionId: str, currentUser: User, @@ -488,7 +488,7 @@ async def _handleInvitationAction( @router.delete("/{notificationId}", response_model=Dict[str, Any]) @limiter.limit("30/minute") -async def deleteNotification( +def deleteNotification( request: Request, notificationId: str, currentUser: User = Depends(getCurrentUser) diff --git a/modules/routes/routeSecurityAdmin.py b/modules/routes/routeSecurityAdmin.py index 75490eac..6a01cd9a 100644 --- a/modules/routes/routeSecurityAdmin.py +++ b/modules/routes/routeSecurityAdmin.py @@ -97,7 +97,7 @@ def _getDatabaseConnector(databaseName: str, userId: str = None) -> DatabaseConn @router.get("/tokens") @limiter.limit("30/minute") -async def list_tokens( +def list_tokens( request: Request, currentUser: User = Depends(requireSysAdmin), userId: Optional[str] = None, @@ -137,7 +137,7 @@ async def list_tokens( @router.post("/tokens/revoke/user") @limiter.limit("30/minute") -async def revoke_tokens_by_user( +def revoke_tokens_by_user( request: Request, currentUser: User = Depends(requireSysAdmin), payload: Dict[str, Any] = Body(...) @@ -172,7 +172,7 @@ async def revoke_tokens_by_user( @router.post("/tokens/revoke/session") @limiter.limit("30/minute") -async def revoke_tokens_by_session( +def revoke_tokens_by_session( request: Request, currentUser: User = Depends(requireSysAdmin), payload: Dict[str, Any] = Body(...) @@ -208,7 +208,7 @@ async def revoke_tokens_by_session( @router.post("/tokens/revoke/id") @limiter.limit("30/minute") -async def revoke_token_by_id( +def revoke_token_by_id( request: Request, currentUser: User = Depends(requireSysAdmin), payload: Dict[str, Any] = Body(...) @@ -235,7 +235,7 @@ async def revoke_token_by_id( @router.post("/tokens/revoke/mandate") @limiter.limit("10/minute") -async def revoke_tokens_by_mandate( +def revoke_tokens_by_mandate( request: Request, currentUser: User = Depends(requireSysAdmin), payload: Dict[str, Any] = Body(...) @@ -280,7 +280,7 @@ async def revoke_tokens_by_mandate( @router.get("/logs/{log_name}") @limiter.limit("60/minute") -async def download_log( +def download_log( request: Request, currentUser: User = Depends(requireSysAdmin), log_name: str = "poweron" @@ -309,7 +309,7 @@ async def download_log( @router.get("/databases") @limiter.limit("10/minute") -async def list_databases( +def list_databases( request: Request, currentUser: User = Depends(requireSysAdmin) ) -> Dict[str, Any]: @@ -327,7 +327,7 @@ async def list_databases( @router.get("/databases/{database_name}/tables") @limiter.limit("30/minute") -async def get_database_tables( +def get_database_tables( request: Request, database_name: str, currentUser: User = Depends(requireSysAdmin) @@ -356,7 +356,7 @@ async def get_database_tables( @router.post("/databases/{database_name}/tables/{table_name}/drop") @limiter.limit("10/minute") -async def drop_table( +def drop_table( request: Request, database_name: str, table_name: str, @@ -404,7 +404,7 @@ async def drop_table( @router.post("/databases/drop") @limiter.limit("5/minute") -async def drop_database( +def drop_database( request: Request, currentUser: User = Depends(requireSysAdmin), payload: Dict[str, Any] = Body(...) diff --git a/modules/routes/routeSecurityGoogle.py b/modules/routes/routeSecurityGoogle.py index 4ee634ed..cfaddc22 100644 --- a/modules/routes/routeSecurityGoogle.py +++ b/modules/routes/routeSecurityGoogle.py @@ -93,7 +93,7 @@ SCOPES = [ ] @router.get("/config") -async def get_config(): +def get_config(): """Debug endpoint to check Google OAuth configuration""" return { "client_id": CLIENT_ID, @@ -109,7 +109,7 @@ async def get_config(): @router.get("/login") @limiter.limit("5/minute") -async def login( +def login( request: Request, state: str = Query("login", description="State parameter to distinguish between login and connection flows"), connectionId: Optional[str] = Query(None, description="Connection ID for connection flow") @@ -589,7 +589,7 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo @router.get("/me", response_model=User) @limiter.limit("30/minute") -async def get_current_user( +def get_current_user( request: Request, currentUser: User = Depends(getCurrentUser) ) -> User: @@ -605,7 +605,7 @@ async def get_current_user( @router.post("/logout") @limiter.limit("10/minute") -async def logout( +def logout( request: Request, currentUser: User = Depends(getCurrentUser) ) -> Dict[str, Any]: diff --git a/modules/routes/routeSecurityLocal.py b/modules/routes/routeSecurityLocal.py index 5f132833..46d00152 100644 --- a/modules/routes/routeSecurityLocal.py +++ b/modules/routes/routeSecurityLocal.py @@ -89,7 +89,7 @@ router = APIRouter( @router.post("/login") @limiter.limit("30/minute") -async def login( +def login( request: Request, response: Response, formData: OAuth2PasswordRequestForm = Depends(), @@ -242,7 +242,7 @@ async def login( @router.post("/register") @limiter.limit("10/minute") -async def register_user( +def register_user( request: Request, userData: User = Body(...), frontendUrl: str = Body(..., embed=True) @@ -381,7 +381,7 @@ Falls Sie sich nicht registriert haben, können Sie diese E-Mail ignorieren.""" @router.get("/me", response_model=User) @limiter.limit("30/minute") -async def read_user_me( +def read_user_me( request: Request, currentUser: User = Depends(getCurrentUser) ) -> User: @@ -397,7 +397,7 @@ async def read_user_me( @router.post("/refresh") @limiter.limit("60/minute") -async def refresh_token( +def refresh_token( request: Request, response: Response ) -> Dict[str, Any]: @@ -472,7 +472,7 @@ async def refresh_token( @router.post("/logout") @limiter.limit("30/minute") -async def logout(request: Request, response: Response, currentUser: User = Depends(getCurrentUser)) -> JSONResponse: +def logout(request: Request, response: Response, currentUser: User = Depends(getCurrentUser)) -> JSONResponse: """Logout from local authentication""" try: # Get user interface with current user context @@ -541,7 +541,7 @@ async def logout(request: Request, response: Response, currentUser: User = Depen @router.get("/available") @limiter.limit("10/minute") -async def check_username_availability( +def check_username_availability( request: Request, username: str, authenticationAuthority: str = "local" @@ -573,7 +573,7 @@ async def check_username_availability( @router.post("/password-reset-request") @limiter.limit("5/minute") -async def password_reset_request( +def password_reset_request( request: Request, username: str = Body(..., embed=True), frontendUrl: str = Body(..., embed=True) @@ -653,7 +653,7 @@ Falls Sie diese Anforderung nicht gestellt haben, können Sie diese E-Mail ignor @router.post("/password-reset") @limiter.limit("10/minute") -async def password_reset( +def password_reset( request: Request, token: str = Body(..., embed=True), password: str = Body(..., embed=True) diff --git a/modules/routes/routeSecurityMsft.py b/modules/routes/routeSecurityMsft.py index 0abb2f56..338e2e33 100644 --- a/modules/routes/routeSecurityMsft.py +++ b/modules/routes/routeSecurityMsft.py @@ -66,7 +66,7 @@ SCOPES = [ @router.get("/login") @limiter.limit("5/minute") -async def login( +def login( request: Request, state: str = Query("login", description="State parameter to distinguish between login and connection flows"), connectionId: Optional[str] = Query(None, description="Connection ID for connection flow") @@ -138,7 +138,7 @@ async def login( @router.get("/adminconsent") @limiter.limit("5/minute") -async def adminconsent(request: Request) -> RedirectResponse: +def adminconsent(request: Request) -> RedirectResponse: """Initiate Microsoft Admin Consent flow. An Azure AD admin must visit this URL once to grant consent for the entire tenant. @@ -161,7 +161,7 @@ async def adminconsent(request: Request) -> RedirectResponse: ) @router.get("/adminconsent/callback") -async def adminconsent_callback( +def adminconsent_callback( admin_consent: Optional[str] = Query(None), tenant: Optional[str] = Query(None), error: Optional[str] = Query(None), @@ -603,7 +603,7 @@ async def auth_callback(code: str, state: str, request: Request, response: Respo @router.get("/me", response_model=User) @limiter.limit("30/minute") -async def get_current_user( +def get_current_user( request: Request, currentUser: User = Depends(getCurrentUser) ) -> User: @@ -619,7 +619,7 @@ async def get_current_user( @router.post("/logout") @limiter.limit("10/minute") -async def logout( +def logout( request: Request, currentUser: User = Depends(getCurrentUser) ) -> Dict[str, Any]: @@ -655,7 +655,7 @@ async def logout( @router.post("/cleanup") @limiter.limit("5/minute") -async def cleanup_expired_tokens( +def cleanup_expired_tokens( request: Request, currentUser: User = Depends(getCurrentUser) ) -> Dict[str, Any]: diff --git a/modules/routes/routeSystem.py b/modules/routes/routeSystem.py index 04e14063..3c8cdd3d 100644 --- a/modules/routes/routeSystem.py +++ b/modules/routes/routeSystem.py @@ -409,7 +409,7 @@ def _formatBlockItem(item: Dict[str, Any], language: str) -> Dict[str, Any]: @navigationRouter.get("/navigation") @limiter.limit("60/minute") -async def get_navigation( +def get_navigation( request: Request, language: str = Query("de", description="Language for labels (en, de, fr)"), reqContext: RequestContext = Depends(getRequestContext) diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py index e63f7932..172fc977 100644 --- a/modules/workflows/automation/mainWorkflow.py +++ b/modules/workflows/automation/mainWorkflow.py @@ -177,17 +177,14 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: workflow = services.interfaceDbChat.updateWorkflow(workflow.id, {"name": workflowName}) logger.info(f"Set workflow {workflow.id} name to: {workflowName}") - # Update automation with execution log + # Save execution log (bypasses RBAC — system operation, not a user edit) executionLogs = list(automation.executionLogs or []) executionLogs.append(executionLog) # Keep only last 50 executions if len(executionLogs) > 50: executionLogs = executionLogs[-50:] - services.interfaceDbAutomation.updateAutomationDefinition( - automationId, - {"executionLogs": executionLogs} - ) + services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs) return workflow except Exception as e: @@ -195,7 +192,7 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: executionLog["status"] = "error" executionLog["messages"].append(f"Error: {str(e)}") - # Update automation with execution log even on error + # Save execution log even on error (bypasses RBAC — system operation) try: automation = services.interfaceDbAutomation.getAutomationDefinition(automationId) if automation: @@ -203,10 +200,7 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: executionLogs.append(executionLog) if len(executionLogs) > 50: executionLogs = executionLogs[-50:] - services.interfaceDbAutomation.updateAutomationDefinition( - automationId, - {"executionLogs": executionLogs} - ) + services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs) except Exception as logError: logger.error(f"Error saving execution log: {str(logError)}") diff --git a/scripts/migrate_async_to_sync.py b/scripts/migrate_async_to_sync.py new file mode 100644 index 00000000..d0f8ef67 --- /dev/null +++ b/scripts/migrate_async_to_sync.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python3 +""" +Migration Script: Convert async def → def for route handlers that don't need async. + +This fixes the event-loop blocking issue where synchronous psycopg2 DB operations +inside async def routes block the entire uvicorn event loop, preventing concurrent +request handling. + +FastAPI behavior: +- `async def` routes → run directly on the event loop (blocks if sync code inside) +- `def` routes → run in a thread pool automatically (non-blocking) + +Usage: + python scripts/migrate_async_to_sync.py --dry-run # Preview changes + python scripts/migrate_async_to_sync.py # Apply changes + +Author: Auto-generated migration script +""" + +import os +import re +import sys +import argparse +from pathlib import Path +from typing import Dict, List, Set, Tuple + +# Base directory +GATEWAY_DIR = Path(__file__).parent.parent +ROUTES_DIR = GATEWAY_DIR / "modules" / "routes" +FEATURES_DIR = GATEWAY_DIR / "modules" / "features" +AUTH_DIR = GATEWAY_DIR / "modules" / "auth" + + +# ============================================================================= +# Configuration: Functions that MUST stay async +# ============================================================================= + +# Key: relative file path from gateway dir +# Value: set of function names that must remain async def +_MUST_STAY_ASYNC: Dict[str, Set[str]] = { + # --- routes/ --- + "modules/routes/routeAdminAutomationEvents.py": { + "sync_all_automation_events", # await syncAutomationEvents(...) + }, + "modules/routes/routeAdminRbacExport.py": { + "import_global_rbac", # await file.read() + "import_mandate_rbac", # await file.read() + }, + "modules/routes/routeDataConnections.py": { + "get_connections", # await token_refresh_service.refresh_expired_tokens(...) + }, + "modules/routes/routeDataFiles.py": { + "upload_file", # await file.read() + }, + "modules/routes/routeDataWorkflows.py": { + "stop_workflow", # await chatStop(...) + }, + # These files have many genuinely async routes (httpx, external APIs) -- keep ALL async: + "modules/routes/routeRealEstate.py": "__ALL__", + "modules/routes/routeSharepoint.py": "__ALL__", + "modules/routes/routeVoiceGoogle.py": "__ALL__", + # Partial keeps in security routes (httpx.AsyncClient, request.json()): + "modules/routes/routeSecurityGoogle.py": { + "verify_google_token", # await client.get(...) + "auth_callback", # await verify_google_token(...), await client.get(...) + "verify_token", # await verify_google_token(...) + "refresh_token", # await request.json() + }, + "modules/routes/routeSecurityMsft.py": { + "auth_callback", # await client.get(...) + "refresh_token", # await request.json() + }, + # --- features/ --- + "modules/features/automation/routeFeatureAutomation.py": { + "execute_automation_route", # await executeAutomation(...) + }, + "modules/features/chatbot/routeFeatureChatbot.py": { + "stream_chatbot_start", # await chatProcess(...), contains async event_stream generator + "event_stream", # await request.is_disconnected(), await asyncio.wait_for(...) + "stop_chatbot", # await event_manager.emit_event(...) + }, + "modules/features/chatplayground/routeFeatureChatplayground.py": { + "start_workflow", # await chatStart(...) + "stop_workflow", # await chatStop(...) + }, + "modules/features/neutralization/routeFeatureNeutralizer.py": { + "process_sharepoint_files", # await service.processSharepointFiles(...) + }, + "modules/features/realestate/routeFeatureRealEstate.py": { + "process_command", # await processNaturalLanguageCommand(...) + "create_table_record", # await create_project_with_parcel_data(...) + "search_parcel", # await connector.search_parcel(...), connector._query_building_layer(...) + "add_parcel_to_project", # await connector.search_parcel(...) + }, + "modules/features/trustee/routeFeatureTrustee.py": { + "create_document", # await request.json() + "upload_document", # await file.read() + }, +} + +# Files to skip entirely (all routes must stay async) +_SKIP_FILES: Set[str] = { + "modules/routes/routeRealEstate.py", + "modules/routes/routeSharepoint.py", + "modules/routes/routeVoiceGoogle.py", +} + +# Helper functions that are fake-async (async def but no await inside) +# These will be converted from async def -> def +_FAKE_ASYNC_HELPERS: Dict[str, Set[str]] = { + "modules/features/chatplayground/routeFeatureChatplayground.py": {"_validateInstanceAccess"}, + "modules/features/trustee/routeFeatureTrustee.py": {"_validateInstanceAccess", "_validateInstanceAdmin"}, + "modules/features/realestate/routeFeatureRealEstate.py": {"_validateInstanceAccess"}, + "modules/features/chatbot/routeFeatureChatbot.py": {"_validateInstanceAccess"}, + "modules/routes/routeNotifications.py": {"_handleInvitationAction"}, +} + +# Calls to these functions should have 'await' removed after they become sync +_REMOVE_AWAIT_CALLS: Set[str] = { + "_validateInstanceAccess", + "_validateInstanceAdmin", + "_handleInvitationAction", +} + + +# ============================================================================= +# Migration Logic +# ============================================================================= + +def _getRelativePath(filePath: Path) -> str: + """Get path relative to gateway dir.""" + try: + return str(filePath.relative_to(GATEWAY_DIR)).replace("\\", "/") + except ValueError: + return str(filePath) + + +def _shouldSkipFile(relPath: str) -> bool: + """Check if file should be skipped entirely.""" + return relPath in _SKIP_FILES or _MUST_STAY_ASYNC.get(relPath) == "__ALL__" + + +def _mustStayAsync(relPath: str, funcName: str) -> bool: + """Check if a specific function must stay async.""" + keepSet = _MUST_STAY_ASYNC.get(relPath, set()) + if keepSet == "__ALL__": + return True + return funcName in keepSet + + +def _isFakeAsyncHelper(relPath: str, funcName: str) -> bool: + """Check if a function is a fake-async helper that should be converted.""" + helpers = _FAKE_ASYNC_HELPERS.get(relPath, set()) + return funcName in helpers + + +def _processFile(filePath: Path, dryRun: bool = True) -> Dict[str, any]: + """Process a single file and convert async def → def where appropriate.""" + relPath = _getRelativePath(filePath) + + if _shouldSkipFile(relPath): + return {"file": relPath, "skipped": True, "reason": "all routes must stay async"} + + with open(filePath, "r", encoding="utf-8") as f: + originalContent = f.read() + + content = originalContent + changes = [] + + # Step 1: Find all async def functions and convert eligible ones + # Pattern matches: async def function_name( + asyncDefPattern = re.compile(r'^(\s*)async def (\w+)\s*\(', re.MULTILINE) + + convertedFunctions = set() + + for match in asyncDefPattern.finditer(originalContent): + indent = match.group(1) + funcName = match.group(2) + + # Check if this function must stay async + if _mustStayAsync(relPath, funcName): + changes.append(f" KEEP async: {funcName} (must stay async)") + continue + + # Convert async def → def + convertedFunctions.add(funcName) + changes.append(f" CONVERT: async def {funcName} -> def {funcName}") + + # Apply conversions + for funcName in convertedFunctions: + # Replace "async def funcName(" with "def funcName(" + # Be careful to match the exact function definition + pattern = re.compile( + r'^(\s*)async def ' + re.escape(funcName) + r'\s*\(', + re.MULTILINE + ) + content = pattern.sub( + lambda m: f'{m.group(1)}def {funcName}(', + content + ) + + # Step 2: Remove 'await' from calls to converted functions + # This handles: await _validateInstanceAccess(...) → _validateInstanceAccess(...) + # And also: result = await someConvertedFunc(...) → result = someConvertedFunc(...) + for funcName in _REMOVE_AWAIT_CALLS: + if funcName in convertedFunctions or _isFakeAsyncHelper(relPath, funcName): + awaitPattern = re.compile( + r'(\s*)(.*)await\s+' + re.escape(funcName) + r'\s*\(', + re.MULTILINE + ) + newContent = awaitPattern.sub( + lambda m: f'{m.group(1)}{m.group(2)}{funcName}(', + content + ) + if newContent != content: + changes.append(f" REMOVE await: await {funcName}(...) -> {funcName}(...)") + content = newContent + + # Step 3: Check for any remaining 'await' in converted functions + # This catches cases where a converted function still has await calls + remainingAwaits = [] + lines = content.split('\n') + currentFunc = None + funcIndent = 0 + + for i, line in enumerate(lines): + # Track current function + defMatch = re.match(r'^(\s*)def (\w+)\s*\(', line) + asyncDefMatch = re.match(r'^(\s*)async def (\w+)\s*\(', line) + + if defMatch and defMatch.group(2) in convertedFunctions: + currentFunc = defMatch.group(2) + funcIndent = len(defMatch.group(1)) + elif defMatch or asyncDefMatch: + currentFunc = None + elif currentFunc and line.strip() and not line[0].isspace(): + currentFunc = None + + # Check for remaining awaits in converted functions + if currentFunc and 'await ' in line: + remainingAwaits.append(f" WARNING: Remaining 'await' in {currentFunc} at line {i+1}: {line.strip()}") + + # Build result + result = { + "file": relPath, + "skipped": False, + "convertedCount": len(convertedFunctions), + "convertedFunctions": sorted(convertedFunctions), + "changes": changes, + "warnings": remainingAwaits, + "modified": content != originalContent, + } + + # Write file if not dry run and content changed + if not dryRun and content != originalContent: + with open(filePath, "w", encoding="utf-8") as f: + f.write(content) + result["written"] = True + else: + result["written"] = False + + return result + + +def _discoverRouteFiles() -> List[Path]: + """Discover all route files to process.""" + files = [] + + # Standard routes + if ROUTES_DIR.exists(): + for f in sorted(ROUTES_DIR.glob("route*.py")): + files.append(f) + + # Feature routes + if FEATURES_DIR.exists(): + for f in sorted(FEATURES_DIR.glob("*/routeFeature*.py")): + files.append(f) + + return files + + +def _main(): + parser = argparse.ArgumentParser( + description="Migrate async def → def for FastAPI routes with sync DB operations" + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="Preview changes without writing files (default: apply changes)" + ) + parser.add_argument( + "--file", + type=str, + default=None, + help="Process only a specific file (relative to gateway dir)" + ) + args = parser.parse_args() + + dryRun = args.dry_run + + print("=" * 70) + print(f" FastAPI Route Migration: async def -> def") + print(f" Mode: {'DRY RUN (preview only)' if dryRun else 'APPLY CHANGES'}") + print("=" * 70) + print() + + # Discover files + if args.file: + targetFile = GATEWAY_DIR / args.file.replace("/", os.sep) + if not targetFile.exists(): + print(f"ERROR: File not found: {targetFile}") + sys.exit(1) + files = [targetFile] + else: + files = _discoverRouteFiles() + + print(f"Found {len(files)} route files to analyze\n") + + totalConverted = 0 + totalWarnings = 0 + totalModified = 0 + allResults = [] + + for filePath in files: + result = _processFile(filePath, dryRun=dryRun) + allResults.append(result) + + if result.get("skipped"): + print(f"[SKIP] {result['file']} - SKIPPED ({result.get('reason', '')})") + continue + + converted = result.get("convertedCount", 0) + warnings = result.get("warnings", []) + modified = result.get("modified", False) + + if converted == 0 and not warnings: + continue + + totalConverted += converted + totalWarnings += len(warnings) + if modified: + totalModified += 1 + + status = "[DONE] WRITTEN" if result.get("written") else ("[PLAN] WOULD WRITE" if modified else "---") + print(f"{status} {result['file']} ({converted} functions)") + + for change in result.get("changes", []): + print(f" {change}") + + for warning in warnings: + print(f" [WARN] {warning}") + + print() + + # Summary + print("=" * 70) + print(f" SUMMARY") + print(f" Files analyzed: {len(files)}") + print(f" Files modified: {totalModified}") + print(f" Functions converted: {totalConverted}") + print(f" Warnings: {totalWarnings}") + if dryRun: + print(f"\n This was a DRY RUN. Run without --dry-run to apply changes.") + else: + print(f"\n Changes applied. Restart the server to take effect.") + print("=" * 70) + + # Return exit code based on warnings + if totalWarnings > 0: + print(f"\n[WARN] There are {totalWarnings} warnings - review before deploying!") + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(_main()) From 887867acd0477d09b72b5a47ac5a59ab9e2778c5 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Sun, 8 Feb 2026 16:14:01 +0100 Subject: [PATCH 14/18] billing rbac --- modules/aicore/aicorePluginPrivateLlm.py | 9 +- modules/datamodels/datamodelBilling.py | 4 + .../automation/routeFeatureAutomation.py | 4 + modules/features/chatbot/service.py | 8 +- .../routeFeatureChatplayground.py | 3 +- modules/interfaces/interfaceAiObjects.py | 25 +- modules/interfaces/interfaceDbBilling.py | 60 +++- modules/routes/routeBilling.py | 270 ++++++++++++++---- modules/services/serviceAi/mainServiceAi.py | 207 ++++++++++---- .../serviceBilling/mainServiceBilling.py | 18 +- .../services/serviceChat/mainServiceChat.py | 66 +---- modules/workflows/automation/mainWorkflow.py | 29 +- 12 files changed, 506 insertions(+), 197 deletions(-) diff --git a/modules/aicore/aicorePluginPrivateLlm.py b/modules/aicore/aicorePluginPrivateLlm.py index dfcb7eaf..718c5905 100644 --- a/modules/aicore/aicorePluginPrivateLlm.py +++ b/modules/aicore/aicorePluginPrivateLlm.py @@ -241,8 +241,10 @@ class AiPrivateLlm(BaseConnectorAi): priority=PriorityEnum.COST, processingMode=ProcessingModeEnum.BASIC, operationTypes=createOperationTypeRatings( - (OperationTypeEnum.DATA_EXTRACT, 9), - (OperationTypeEnum.DATA_ANALYSE, 9), + (OperationTypeEnum.PLAN, 7), + (OperationTypeEnum.DATA_ANALYSE, 8), + (OperationTypeEnum.DATA_GENERATE, 8), + (OperationTypeEnum.DATA_EXTRACT, 8), ), version="qwen2.5:7b", calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_TEXT_PER_CALL @@ -268,7 +270,6 @@ class AiPrivateLlm(BaseConnectorAi): processingMode=ProcessingModeEnum.ADVANCED, operationTypes=createOperationTypeRatings( (OperationTypeEnum.IMAGE_ANALYSE, 9), - (OperationTypeEnum.DATA_EXTRACT, 8), ), version="qwen2.5vl:7b", calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL @@ -294,8 +295,6 @@ class AiPrivateLlm(BaseConnectorAi): processingMode=ProcessingModeEnum.DETAILED, operationTypes=createOperationTypeRatings( (OperationTypeEnum.IMAGE_ANALYSE, 9), - (OperationTypeEnum.DATA_EXTRACT, 9), - (OperationTypeEnum.DATA_ANALYSE, 8), ), version="granite3.2-vision", calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL diff --git a/modules/datamodels/datamodelBilling.py b/modules/datamodels/datamodelBilling.py index e7e59eb4..ea5a4c00 100644 --- a/modules/datamodels/datamodelBilling.py +++ b/modules/datamodels/datamodelBilling.py @@ -121,6 +121,8 @@ class BillingTransaction(BaseModel): featureInstanceId: Optional[str] = Field(None, description="Feature instance ID") featureCode: Optional[str] = Field(None, description="Feature code (e.g., chatplayground, automation)") aicoreProvider: Optional[str] = Field(None, description="AICore provider (anthropic, openai, etc.)") + aicoreModel: Optional[str] = Field(None, description="AICore model name (e.g., claude-4-sonnet, gpt-4o)") + createdByUserId: Optional[str] = Field(None, description="User who created/caused this transaction") registerModelLabels( @@ -138,6 +140,8 @@ registerModelLabels( "featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID"}, "featureCode": {"en": "Feature Code", "de": "Feature-Code"}, "aicoreProvider": {"en": "AI Provider", "de": "AI-Anbieter"}, + "aicoreModel": {"en": "AI Model", "de": "AI-Modell"}, + "createdByUserId": {"en": "Created By User", "de": "Erstellt von Benutzer"}, }, ) diff --git a/modules/features/automation/routeFeatureAutomation.py b/modules/features/automation/routeFeatureAutomation.py index d39a3358..8a5fee1d 100644 --- a/modules/features/automation/routeFeatureAutomation.py +++ b/modules/features/automation/routeFeatureAutomation.py @@ -367,6 +367,10 @@ async def execute_automation_route( try: from modules.services import getInterface as getServices services = getServices(context.user, context.mandateId) + # Propagate feature context for billing + if context.featureInstanceId: + services.featureInstanceId = str(context.featureInstanceId) + services.featureCode = 'automation' workflow = await executeAutomation(automationId, services) return workflow except HTTPException: diff --git a/modules/features/chatbot/service.py b/modules/features/chatbot/service.py index fd7f3344..3afd1632 100644 --- a/modules/features/chatbot/service.py +++ b/modules/features/chatbot/service.py @@ -93,6 +93,12 @@ async def chatProcess( try: # Get services with mandate context services = getServices(currentUser, mandateId) + + # Set feature context for billing + if featureInstanceId: + services.featureInstanceId = featureInstanceId + services.featureCode = 'chatbot' + interfaceDbChat = services.interfaceDbChat # Get event manager and create queue if needed @@ -698,7 +704,7 @@ async def _convert_file_ids_to_document_references( # Search database if not found in messages if not document_id: try: - from modules.shared.databaseUtils import getRecordsetWithRBAC + from modules.interfaces.interfaceRbac import getRecordsetWithRBAC documents = getRecordsetWithRBAC( services.interfaceDbChat.db, ChatDocument, diff --git a/modules/features/chatplayground/routeFeatureChatplayground.py b/modules/features/chatplayground/routeFeatureChatplayground.py index ce1611ea..e3787904 100644 --- a/modules/features/chatplayground/routeFeatureChatplayground.py +++ b/modules/features/chatplayground/routeFeatureChatplayground.py @@ -102,7 +102,8 @@ async def start_workflow( workflowMode, workflowId, mandateId=mandateId, - featureInstanceId=instanceId + featureInstanceId=instanceId, + featureCode='chatplayground' ) return workflow diff --git a/modules/interfaces/interfaceAiObjects.py b/modules/interfaces/interfaceAiObjects.py index 3976350d..0214d231 100644 --- a/modules/interfaces/interfaceAiObjects.py +++ b/modules/interfaces/interfaceAiObjects.py @@ -4,8 +4,8 @@ import logging import asyncio import uuid import base64 -from typing import Dict, Any, List, Union, Tuple, Optional -from dataclasses import dataclass +from typing import Dict, Any, List, Union, Tuple, Optional, Callable +from dataclasses import dataclass, field import time logger = logging.getLogger(__name__) @@ -29,7 +29,13 @@ from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy @dataclass(slots=True) class AiObjects: - """Centralized AI interface: dynamically discovers and uses AI models. Includes web functionality.""" + """Centralized AI interface: dynamically discovers and uses AI models. + + billingCallback: Set by serviceAi before AI calls. Called after EVERY individual + model call with the AiCallResponse. This ensures per-model-call billing with + exact provider + model name. The callback handles billing recording. + """ + billingCallback: Optional[Callable] = field(default=None, repr=False) def __post_init__(self) -> None: # Auto-discover and register all available connectors @@ -226,7 +232,7 @@ class AiObjects: # Calculate price using model's own price calculation method priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes) - return AiCallResponse( + response = AiCallResponse( content=content, modelName=model.name, provider=model.connectorType, @@ -236,6 +242,17 @@ class AiObjects: bytesReceived=outputBytes, errorCount=0 ) + + # BILLING: Record billing for THIS specific model call + # billingCallback is set by serviceAi and records one billing transaction + # per model call with exact provider + model name + if self.billingCallback: + try: + self.billingCallback(response) + except Exception as e: + logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}") + + return response # Utility methods diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index ae8b13ec..a8cbd61b 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -710,6 +710,7 @@ class BillingObjects: featureInstanceId: str = None, featureCode: str = None, aicoreProvider: str = None, + aicoreModel: str = None, description: str = "AI Usage" ) -> Optional[Dict[str, Any]]: """ @@ -722,7 +723,8 @@ class BillingObjects: workflowId: Optional workflow ID featureInstanceId: Optional feature instance ID featureCode: Optional feature code - aicoreProvider: Optional AICore provider name + aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai') + aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o') description: Transaction description Returns: @@ -758,7 +760,9 @@ class BillingObjects: workflowId=workflowId, featureInstanceId=featureInstanceId, featureCode=featureCode, - aicoreProvider=aicoreProvider + aicoreProvider=aicoreProvider, + aicoreModel=aicoreModel, + createdByUserId=userId ) return self.createTransaction(transaction) @@ -828,9 +832,13 @@ class BillingObjects: # Calculate by provider costByProvider = {} + costByModel = {} for t in debits: provider = t.get("aicoreProvider", "unknown") costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0) + + model = t.get("aicoreModel", "unknown") + costByModel[model] = costByModel.get(model, 0) + t.get("amount", 0) # Calculate by feature costByFeature = {} @@ -842,6 +850,7 @@ class BillingObjects: "totalCostCHF": totalCost, "transactionCount": len(debits), "costByProvider": costByProvider, + "costByModel": costByModel, "costByFeature": costByFeature } @@ -1129,8 +1138,8 @@ class BillingObjects: user = appInterface.getUser(userId) if user: displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) - email = getattr(user, 'email', None) or (user.get("email") if isinstance(user, dict) else None) - userMap[userId] = displayName or email or userId + username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None) + userMap[userId] = displayName or username or userId # Get mandate info efficiently mandateMap = {} @@ -1212,8 +1221,8 @@ class BillingObjects: user = appInterface.getUser(userId) if user: displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) - email = getattr(user, 'email', None) or (user.get("email") if isinstance(user, dict) else None) - userMap[userId] = displayName or email or userId + username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None) + userMap[userId] = displayName or username or userId # Get mandate info efficiently mandateMap = {} @@ -1224,7 +1233,8 @@ class BillingObjects: mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") mandateMap[mandateId] = mandateName - # Get transactions for all accounts + # Get transactions for all accounts and collect createdByUserIds + rawTransactions = [] for account in allAccounts: accountId = account.get("id") if not accountId: @@ -1233,14 +1243,38 @@ class BillingObjects: transactions = self.getTransactions(accountId, limit=limit) accountInfo = accountMap.get(accountId, {}) mandateId = accountInfo.get("mandateId") - userId = accountInfo.get("userId") + accountUserId = accountInfo.get("userId") for t in transactions: - t["mandateId"] = mandateId - t["mandateName"] = mandateMap.get(mandateId, "") - t["userId"] = userId - t["userName"] = userMap.get(userId, userId) - allTransactions.append(t) + t["_accountUserId"] = accountUserId + t["_accountMandateId"] = mandateId + rawTransactions.append(t) + + # Resolve createdByUserIds that are not yet in userMap + extraUserIds = set() + for t in rawTransactions: + cbUserId = t.get("createdByUserId") + if cbUserId and cbUserId not in userMap: + extraUserIds.add(cbUserId) + + for uid in extraUserIds: + user = appInterface.getUser(uid) + if user: + displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) + username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None) + userMap[uid] = displayName or username or uid + + # Enrich transactions + for t in rawTransactions: + mandateId = t.pop("_accountMandateId", None) + accountUserId = t.pop("_accountUserId", None) + t["mandateId"] = mandateId + t["mandateName"] = mandateMap.get(mandateId, "") + # Prefer createdByUserId (per-transaction) over account-derived userId + txUserId = t.get("createdByUserId") or accountUserId + t["userId"] = txUserId + t["userName"] = userMap.get(txUserId, txUserId) if txUserId else None + allTransactions.append(t) except Exception as e: logger.error(f"Error getting user transactions for mandates: {e}") diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index 26133704..3e87cd83 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -42,6 +42,143 @@ from modules.datamodels.datamodelBilling import ( # Configure logger logger = logging.getLogger(__name__) + +# ============================================================================= +# Billing RBAC Data Scope +# ============================================================================= +# +# RBAC rules for billing data visibility: +# +# SysAdmin → ALL transactions and statistics across all mandates +# Mandate-Admin → ALL user data within their administrated mandates +# Feature-Instance-Admin→ Data for their administrated feature instances +# Regular User → ONLY their own data within their mandates +# + +class BillingDataScope: + """ + Determines what billing data a user can see based on RBAC roles. + + Evaluated once per request and used to filter transactions/statistics. + """ + __slots__ = ('isGlobalAdmin', 'adminMandateIds', 'adminFeatureInstanceIds', + 'memberMandateIds', 'userId') + + def __init__(self, userId: str): + self.isGlobalAdmin: bool = False + self.adminMandateIds: list = [] + self.adminFeatureInstanceIds: list = [] + self.memberMandateIds: list = [] + self.userId: str = userId + + +def _getBillingDataScope(user) -> BillingDataScope: + """ + Determine what billing data a user can see based on RBAC. + + Uses rootInterface (privileged) to check roles across all mandates + and feature instances without RBAC restrictions on the lookup itself. + + Returns: + BillingDataScope with the user's visibility boundaries. + """ + scope = BillingDataScope(userId=user.id) + + if user.isSysAdmin: + scope.isGlobalAdmin = True + return scope + + from modules.interfaces.interfaceDbApp import getRootInterface + rootInterface = getRootInterface() + + # --- Mandate roles --- + userMandates = rootInterface.getUserMandates(user.id) + for um in userMandates: + mandateId = getattr(um, 'mandateId', None) + umId = getattr(um, 'id', None) + if not mandateId or not umId: + continue + + roleIds = rootInterface.getRoleIdsForUserMandate(umId) + isAdmin = False + for roleId in roleIds: + role = rootInterface.getRole(roleId) + if role and role.roleLabel == "admin" and not role.featureInstanceId: + isAdmin = True + break + + if isAdmin: + scope.adminMandateIds.append(mandateId) + else: + scope.memberMandateIds.append(mandateId) + + # --- Feature instance roles --- + featureAccesses = rootInterface.getFeatureAccessesForUser(user.id) + for fa in featureAccesses: + fiId = getattr(fa, 'featureInstanceId', None) + faId = getattr(fa, 'id', None) + if not fiId or not faId: + continue + + roleIds = rootInterface.getRoleIdsForFeatureAccess(faId) + for roleId in roleIds: + role = rootInterface.getRole(roleId) + if role and role.roleLabel == "admin": + scope.adminFeatureInstanceIds.append(fiId) + break + + logger.debug( + f"BillingDataScope for user {user.id}: " + f"globalAdmin={scope.isGlobalAdmin}, " + f"adminMandates={scope.adminMandateIds}, " + f"adminInstances={scope.adminFeatureInstanceIds}, " + f"memberMandates={scope.memberMandateIds}" + ) + return scope + + +def _filterTransactionsByScope(transactions: list, scope: BillingDataScope) -> list: + """ + Filter a list of transaction dicts based on the user's BillingDataScope. + + Rules: + - SysAdmin: no filter + - Mandate-Admin: all transactions in their admin mandates + - Feature-Instance-Admin: transactions for their admin feature instances + - Regular user: only transactions where createdByUserId/userId matches + """ + if scope.isGlobalAdmin: + return transactions + + adminMandateSet = set(scope.adminMandateIds) + adminFiSet = set(scope.adminFeatureInstanceIds) + memberMandateSet = set(scope.memberMandateIds) + + result = [] + for t in transactions: + mandateId = t.get("mandateId") + fiId = t.get("featureInstanceId") + txUserId = t.get("createdByUserId") or t.get("userId") + + # Mandate admin → sees all transactions in their mandate + if mandateId and mandateId in adminMandateSet: + result.append(t) + continue + + # Feature instance admin → sees all transactions for their instances + if fiId and fiId in adminFiSet: + result.append(t) + continue + + # Regular member → only own transactions + if mandateId and mandateId in memberMandateSet: + if txUserId and txUserId == scope.userId: + result.append(t) + continue + + return result + + # ============================================================================= # Request/Response Models # ============================================================================= @@ -74,7 +211,10 @@ class TransactionResponse(BaseModel): referenceType: Optional[ReferenceTypeEnum] workflowId: Optional[str] featureCode: Optional[str] + featureInstanceId: Optional[str] = None aicoreProvider: Optional[str] + aicoreModel: Optional[str] = None + createdByUserId: Optional[str] = None createdAt: Optional[datetime] mandateId: Optional[str] = None mandateName: Optional[str] = None @@ -98,6 +238,7 @@ class UsageReportResponse(BaseModel): totalCost: float transactionCount: int costByProvider: Dict[str, float] + costByModel: Dict[str, float] = {} costByFeature: Dict[str, float] @@ -140,7 +281,10 @@ class UserTransactionResponse(BaseModel): referenceType: Optional[ReferenceTypeEnum] workflowId: Optional[str] featureCode: Optional[str] + featureInstanceId: Optional[str] = None aicoreProvider: Optional[str] + aicoreModel: Optional[str] = None + createdByUserId: Optional[str] = None createdAt: Optional[datetime] mandateId: Optional[str] = None mandateName: Optional[str] = None @@ -261,7 +405,10 @@ def getTransactions( referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, workflowId=t.get("workflowId"), featureCode=t.get("featureCode"), + featureInstanceId=t.get("featureInstanceId"), aicoreProvider=t.get("aicoreProvider"), + aicoreModel=t.get("aicoreModel"), + createdByUserId=t.get("createdByUserId"), createdAt=t.get("_createdAt"), mandateId=t.get("mandateId"), mandateName=t.get("mandateName") @@ -349,6 +496,7 @@ def getStatistics( totalCost=stats.get("totalCostCHF", 0.0), transactionCount=stats.get("transactionCount", 0), costByProvider=stats.get("costByProvider", {}), + costByModel=stats.get("costByModel", {}), costByFeature=stats.get("costByFeature", {}) ) @@ -564,6 +712,7 @@ def getAccounts( class MandateUserSummary(BaseModel): """Summary of a user for billing admin purposes.""" id: str + username: Optional[str] = None email: Optional[str] = None firstName: Optional[str] = None lastName: Optional[str] = None @@ -600,18 +749,21 @@ def getUsersForMandate( # Handle both Pydantic models and dicts if isinstance(user, dict): + username = user.get("username", "") firstName = user.get("firstName", "") lastName = user.get("lastName", "") email = user.get("email", "") else: + username = getattr(user, "username", "") or "" firstName = getattr(user, "firstName", "") or "" lastName = getattr(user, "lastName", "") or "" email = getattr(user, "email", "") or "" - displayName = f"{firstName} {lastName}".strip() or email or userId + displayName = f"{firstName} {lastName}".strip() or username or userId result.append(MandateUserSummary( id=userId, + username=username, email=email, firstName=firstName, lastName=lastName, @@ -652,7 +804,10 @@ def getTransactionsAdmin( referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, workflowId=t.get("workflowId"), featureCode=t.get("featureCode"), + featureInstanceId=t.get("featureInstanceId"), aicoreProvider=t.get("aicoreProvider"), + aicoreModel=t.get("aicoreModel"), + createdByUserId=t.get("createdByUserId"), createdAt=t.get("_createdAt") )) @@ -715,7 +870,10 @@ def getMandateViewTransactions( referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, workflowId=t.get("workflowId"), featureCode=t.get("featureCode"), + featureInstanceId=t.get("featureInstanceId"), aicoreProvider=t.get("aicoreProvider"), + aicoreModel=t.get("aicoreModel"), + createdByUserId=t.get("createdByUserId"), createdAt=t.get("_createdAt"), mandateId=t.get("mandateId"), mandateName=t.get("mandateName") @@ -740,39 +898,35 @@ def getUserViewBalances( ): """ Get user-level balances. + + RBAC filtering: - SysAdmin: sees all user balances across all mandates - - MandateAdmin: sees user balances for mandates they manage + - Mandate-Admin: sees user balances for mandates they administrate - Regular user: sees only their own balances """ try: billingInterface = getBillingInterface(ctx.user, ctx.mandateId) - # Determine which mandates the user has access to - if ctx.user.isSysAdmin: - # SysAdmin sees all + # Evaluate RBAC scope + scope = _getBillingDataScope(ctx.user) + + # Determine mandate IDs for data loading + if scope.isGlobalAdmin: mandateIds = None else: - # Get mandates where user is admin or has billing access - from modules.interfaces.interfaceDbApp import getInterface as getAppInterface - appInterface = getAppInterface(ctx.user) - userMandates = appInterface.getUserMandates(ctx.user.id) - - # Filter to only mandates where user has admin role - # For simplicity, we'll check if user is admin in any mandate - mandateIds = [] - for um in userMandates: - mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) - if mandateId: - mandateIds.append(mandateId) - + mandateIds = scope.adminMandateIds + scope.memberMandateIds if not mandateIds: return [] allBalances = billingInterface.getUserBalancesForMandates(mandateIds) - # Non-admin users only see their own balances - if not ctx.user.isSysAdmin: - allBalances = [b for b in allBalances if b.get("userId") == ctx.user.id] + # RBAC filter: mandate admins see all in their mandates, regular users only own + if not scope.isGlobalAdmin: + adminMandateSet = set(scope.adminMandateIds) + allBalances = [ + b for b in allBalances + if b.get("mandateId") in adminMandateSet or b.get("userId") == scope.userId + ] return [UserBalanceResponse(**b) for b in allBalances] @@ -786,6 +940,7 @@ class ViewStatisticsResponse(BaseModel): totalCost: float = 0.0 transactionCount: int = 0 costByProvider: Dict[str, float] = {} + costByModel: Dict[str, float] = {} costByFeature: Dict[str, float] = {} costByMandate: Dict[str, float] = {} timeSeries: List[Dict[str, Any]] = [] @@ -802,6 +957,13 @@ def getUserViewStatistics( ) -> ViewStatisticsResponse: """ Get aggregated usage statistics across all user's mandates. + + RBAC filtering: + - SysAdmin: statistics across all mandates + - Mandate-Admin: statistics for mandates they administrate + - Feature-Instance-Admin: statistics for their feature instances + - Regular user: only their own usage statistics + - period='month': returns monthly time series for the given year - period='day': returns daily time series for the given month/year """ @@ -816,25 +978,25 @@ def getUserViewStatistics( billingInterface = getBillingInterface(ctx.user, ctx.mandateId) - # Get all mandates the user has access to - if ctx.user.isSysAdmin: + # Evaluate RBAC scope + scope = _getBillingDataScope(ctx.user) + + # Determine mandate IDs for data loading + if scope.isGlobalAdmin: mandateIds = None else: - from modules.interfaces.interfaceDbApp import getInterface as getAppInterface - appInterface = getAppInterface(ctx.user) - userMandates = appInterface.getUserMandates(ctx.user.id) - mandateIds = [] - for um in userMandates: - mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) - if mandateId: - mandateIds.append(mandateId) + mandateIds = scope.adminMandateIds + scope.memberMandateIds if not mandateIds: logger.warning("No mandate IDs found for user") return ViewStatisticsResponse() # Get all transactions allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000) - logger.info(f"View statistics: {len(allTransactions)} total transactions fetched for period={period}, year={year}, month={month}") + + # Apply RBAC filter + allTransactions = _filterTransactionsByScope(allTransactions, scope) + + logger.info(f"View statistics: {len(allTransactions)} RBAC-filtered transactions for period={period}, year={year}, month={month}") # Calculate date range if period == "day": @@ -905,6 +1067,7 @@ def getUserViewStatistics( totalCost = sum(t.get("amount", 0) for t in debits) costByProvider: Dict[str, float] = {} + costByModel: Dict[str, float] = {} costByFeature: Dict[str, float] = {} costByMandate: Dict[str, float] = {} @@ -912,6 +1075,9 @@ def getUserViewStatistics( provider = t.get("aicoreProvider") or "unknown" costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0) + model = t.get("aicoreModel") or "unknown" + costByModel[model] = costByModel.get(model, 0) + t.get("amount", 0) + mandate = t.get("mandateName") or t.get("mandateId") or "unknown" featureCode = t.get("featureCode") or "unknown" featureKey = f"{mandate} / {featureCode}" @@ -950,6 +1116,7 @@ def getUserViewStatistics( totalCost=round(totalCost, 4), transactionCount=len(debits), costByProvider=costByProvider, + costByModel=costByModel, costByFeature=costByFeature, costByMandate=costByMandate, timeSeries=timeSeries @@ -969,8 +1136,11 @@ def getUserViewTransactions( ) -> PaginatedResponse[UserTransactionResponse]: """ Get user-level transactions with pagination support. + + RBAC filtering: - SysAdmin: sees all user transactions across all mandates - - MandateAdmin: sees user transactions for mandates they manage + - Mandate-Admin: sees all user transactions for mandates they administrate + - Feature-Instance-Admin: sees transactions for their feature instances - Regular user: sees only their own transactions Query Parameters: @@ -987,28 +1157,24 @@ def getUserViewTransactions( paginationDict = normalize_pagination_dict(paginationDict) paginationParams = PaginationParams(**paginationDict) - # Determine which mandates the user has access to - if ctx.user.isSysAdmin: - # SysAdmin sees all - mandateIds = None + # Evaluate RBAC scope + scope = _getBillingDataScope(ctx.user) + + # Determine mandate IDs for data loading + if scope.isGlobalAdmin: + mandateIds = None # Load all else: - # Get mandates where user has access - from modules.interfaces.interfaceDbApp import getInterface as getAppInterface - appInterface = getAppInterface(ctx.user) - userMandates = appInterface.getUserMandates(ctx.user.id) - - mandateIds = [] - for um in userMandates: - mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) - if mandateId: - mandateIds.append(mandateId) - + # Load data for all mandates the user belongs to (admin + member) + mandateIds = scope.adminMandateIds + scope.memberMandateIds if not mandateIds: return PaginatedResponse(items=[], pagination=None) allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000) - logger.debug(f"Found {len(allTransactions)} transactions for mandates {mandateIds}") + # Apply RBAC filter + allTransactions = _filterTransactionsByScope(allTransactions, scope) + + logger.debug(f"RBAC-filtered {len(allTransactions)} transactions for user {ctx.user.id}") # Convert to response objects as dicts for filtering/sorting transactionDicts = [] @@ -1022,7 +1188,10 @@ def getUserViewTransactions( "referenceType": t.get("referenceType"), "workflowId": t.get("workflowId"), "featureCode": t.get("featureCode"), + "featureInstanceId": t.get("featureInstanceId"), "aicoreProvider": t.get("aicoreProvider"), + "aicoreModel": t.get("aicoreModel"), + "createdByUserId": t.get("createdByUserId"), "createdAt": t.get("_createdAt"), "mandateId": t.get("mandateId"), "mandateName": t.get("mandateName"), @@ -1044,7 +1213,10 @@ def getUserViewTransactions( referenceType=ReferenceTypeEnum(d["referenceType"]) if d.get("referenceType") else None, workflowId=d.get("workflowId"), featureCode=d.get("featureCode"), + featureInstanceId=d.get("featureInstanceId"), aicoreProvider=d.get("aicoreProvider"), + aicoreModel=d.get("aicoreModel"), + createdByUserId=d.get("createdByUserId"), createdAt=d.get("createdAt"), mandateId=d.get("mandateId"), mandateName=d.get("mandateName"), diff --git a/modules/services/serviceAi/mainServiceAi.py b/modules/services/serviceAi/mainServiceAi.py index 5fdf32a5..bd3ef6b4 100644 --- a/modules/services/serviceAi/mainServiceAi.py +++ b/modules/services/serviceAi/mainServiceAi.py @@ -21,7 +21,8 @@ from modules.datamodels.datamodelAi import JsonAccumulationState from modules.services.serviceBilling.mainServiceBilling import ( getService as getBillingService, InsufficientBalanceException, - ProviderNotAllowedException + ProviderNotAllowedException, + BillingContextError ) logger = logging.getLogger(__name__) @@ -88,41 +89,104 @@ class AiService: async def callAi(self, request: AiCallRequest, progressCallback=None): """Router: handles content parts via extractionService, text context via interface. - Replaces direct calls to self.aiObjects.call() to route content parts processing - through serviceExtraction layer. - - Includes billing checks: - - Balance check before AI call - - Provider permission check (via RBAC) - - Also stores workflow stats after each successful AI call. + FAIL-SAFE BILLING at the source: + 1. Pre-flight check: validates billing context is complete (RAISES if not) + 2. Balance & provider check before AI call + 3. billingCallback on aiObjects: records one billing transaction per model call + with exact provider + model name (set before AI call, invoked by _callWithModel) """ - # Billing check before AI call (validates RBAC permissions) + # FAIL-SAFE: Pre-flight billing validation (like 0 CHF credit card check) + self._preflightBillingCheck() + + # Balance & provider permission checks await self._checkBillingBeforeAiCall() # Calculate effective allowedProviders: RBAC ∩ Workflow - # RBAC is master - only RBAC-permitted providers can ever be used effectiveProviders = self._calculateEffectiveProviders() if effectiveProviders and request.options: request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders}) logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}") - if hasattr(request, 'contentParts') and request.contentParts: - response = await self.extractionService.processContentPartsWithAi( - request, self.aiObjects, progressCallback - ) - else: - response = await self.aiObjects.callWithTextContext(request) + # Set billing callback on aiObjects BEFORE the AI call + # This callback is invoked by _callWithModel() after EVERY individual model call + # For parallel content parts (e.g., 200 MB doc), each model call creates its own transaction + self.aiObjects.billingCallback = self._createBillingCallback() - # Store workflow stats after each AI call + try: + if hasattr(request, 'contentParts') and request.contentParts: + response = await self.extractionService.processContentPartsWithAi( + request, self.aiObjects, progressCallback + ) + else: + response = await self.aiObjects.callWithTextContext(request) + finally: + # Clear callback after call completes + self.aiObjects.billingCallback = None + + # Store workflow stats for analytics self._storeAiCallStats(response, request) return response + def _preflightBillingCheck(self) -> None: + """ + Pre-flight billing validation - like a 0 CHF credit card authorization check. + + Validates that ALL required billing context is present and that a billing + transaction CAN be recorded. This dry-run check catches missing context + BEFORE an expensive AI call starts. + + FAIL-SAFE: This method RAISES if billing context is incomplete. + An AI call without billing context MUST NOT proceed. + + Raises: + BillingContextError: If billing context is incomplete or invalid + """ + if not self.services: + raise BillingContextError("No service context available - cannot bill AI call") + + user = getattr(self.services, 'user', None) + if not user: + raise BillingContextError("No user context - cannot bill AI call") + + mandateId = getattr(self.services, 'mandateId', None) + if not mandateId: + raise BillingContextError( + f"No mandateId in service context for user {user.id} - cannot bill AI call. " + "Every AI call MUST have a mandate context for billing." + ) + + # Validate billing service can be created + featureInstanceId = getattr(self.services, 'featureInstanceId', None) + featureCode = getattr(self.services, 'featureCode', None) + + try: + billingService = getBillingService(user, mandateId, featureInstanceId, featureCode) + except Exception as e: + raise BillingContextError( + f"Cannot create billing service for user {user.id}, mandate {mandateId}: {e}" + ) + + # Dry-run: verify billing service can check balance (DB accessible) + try: + billingService.checkBalance(0.0) + except Exception as e: + raise BillingContextError( + f"Billing system not accessible for mandate {mandateId}: {e}" + ) + + logger.debug( + f"Pre-flight billing check PASSED: user={user.id}, mandate={mandateId}, " + f"feature={featureCode or 'none'}, instance={featureInstanceId or 'none'}" + ) + async def _checkBillingBeforeAiCall(self) -> None: """ Check billing status before making an AI call. + FAIL-SAFE: Context validation is done in _preflightBillingCheck() which is + called first. This method handles balance and provider permission checks. + Verifies: 1. User has sufficient balance (for prepay models) 2. Provider is allowed for the user (via RBAC) @@ -130,34 +194,19 @@ class AiService: Raises: InsufficientBalanceException: If balance is insufficient ProviderNotAllowedException: If provider is not allowed + BillingContextError: If billing check fails unexpectedly """ + # Context is already validated by _preflightBillingCheck() + user = self.services.user + mandateId = self.services.mandateId + featureInstanceId = getattr(self.services, 'featureInstanceId', None) + featureCode = getattr(self.services, 'featureCode', None) + try: - # Get context from services - if not self.services: - logger.debug("No service center - skipping billing check") - return - - user = getattr(self.services, 'user', None) - mandateId = getattr(self.services, 'mandateId', None) - - if not user or not mandateId: - logger.debug("No user or mandate context - skipping billing check") - return - - # Get feature context - featureInstanceId = getattr(self.services, 'featureInstanceId', None) - featureCode = getattr(self.services, 'featureCode', None) - # Get billing service - billingService = getBillingService( - user, - mandateId, - featureInstanceId, - featureCode - ) + billingService = getBillingService(user, mandateId, featureInstanceId, featureCode) # Check balance (estimate typical AI call cost) - # We use a small estimate here; actual cost is recorded after the call estimatedCost = 0.01 # ~1 cent CHF minimum balanceCheck = billingService.checkBalance(estimatedCost) @@ -170,7 +219,7 @@ class AiService: raise InsufficientBalanceException( currentBalance=balanceCheck.currentBalance or 0.0, requiredAmount=estimatedCost, - message=f"Ungenügendes Guthaben. Aktuell: CHF {balanceCheck.currentBalance:.2f}" + message=f"Ungenugendes Guthaben. Aktuell: CHF {balanceCheck.currentBalance:.2f}" ) logger.debug(f"Billing check passed: Balance {balanceCheck.currentBalance:.2f} CHF") @@ -181,13 +230,12 @@ class AiService: logger.warning(f"No AI providers allowed for user {user.id} in mandate {mandateId}") raise ProviderNotAllowedException( provider="any", - message="Keine AI-Provider für Ihre Rolle freigegeben. Kontaktieren Sie Ihren Administrator." + message="Keine AI-Provider fuer Ihre Rolle freigegeben. Kontaktieren Sie Ihren Administrator." ) # Check automation-level allowedProviders restriction automationAllowedProviders = getattr(self.services, 'allowedProviders', None) if automationAllowedProviders: - # Filter by both RBAC and automation-level restrictions effectiveProviders = [p for p in automationAllowedProviders if p in rbacAllowedProviders] if not effectiveProviders: logger.warning(f"No providers available after automation restriction. " @@ -195,7 +243,7 @@ class AiService: f"RBAC allows: {rbacAllowedProviders}") raise ProviderNotAllowedException( provider="any", - message="Die konfigurierten AI-Provider dieser Automation sind für Ihre Rolle nicht freigegeben." + message="Die konfigurierten AI-Provider dieser Automation sind fuer Ihre Rolle nicht freigegeben." ) logger.debug(f"Automation provider check passed: {effectiveProviders}") @@ -207,19 +255,78 @@ class AiService: logger.warning(f"Preferred provider {provider} not allowed for user {user.id}") raise ProviderNotAllowedException( provider=provider, - message=f"Der gewählte Provider '{provider}' ist für Ihre Rolle nicht freigegeben." + message=f"Der gewaehlte Provider '{provider}' ist fuer Ihre Rolle nicht freigegeben." ) logger.debug(f"All preferred providers are allowed: {preferredProviders}") logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed") except InsufficientBalanceException: - raise # Re-raise billing exceptions + raise except ProviderNotAllowedException: - raise # Re-raise provider exceptions + raise + except BillingContextError: + raise except Exception as e: - # Log but don't block on billing check errors - logger.warning(f"Billing check failed with error (non-blocking): {e}") + # FAIL-SAFE: Don't silently swallow errors - log at ERROR level + logger.error(f"BILLING FAIL-SAFE: Billing check failed with unexpected error: {e}") + raise BillingContextError(f"Billing check failed: {e}") + + def _createBillingCallback(self): + """ + Create a billing callback for interfaceAiObjects._callWithModel(). + + Returns a function that records one billing transaction per individual model call. + Each transaction contains the exact provider name AND model name. + + For a 200 MB document processed with N parallel AI calls (possibly different models), + this creates N separate billing transactions - one per model call. + """ + user = self.services.user + mandateId = self.services.mandateId + featureInstanceId = getattr(self.services, 'featureInstanceId', None) + featureCode = getattr(self.services, 'featureCode', None) + + # Get workflow ID if available + workflowId = None + workflow = getattr(self.services, 'workflow', None) + if workflow and hasattr(workflow, 'id'): + workflowId = workflow.id + + billingService = getBillingService(user, mandateId, featureInstanceId, featureCode) + + def _billingCallback(response) -> None: + """Record billing for a single AI model call.""" + if not response or getattr(response, 'errorCount', 0) > 0: + return + + priceCHF = getattr(response, 'priceCHF', 0.0) + if not priceCHF or priceCHF <= 0: + return + + provider = getattr(response, 'provider', None) or 'unknown' + modelName = getattr(response, 'modelName', None) or 'unknown' + + try: + billingService.recordUsage( + priceCHF=priceCHF, + workflowId=workflowId, + aicoreProvider=provider, + aicoreModel=modelName, + description=f"AI: {modelName}" + ) + logger.debug( + f"Billed model call: {priceCHF:.4f} CHF, " + f"provider={provider}, model={modelName}, mandate={mandateId}" + ) + except Exception as e: + logger.error( + f"BILLING: Failed to record transaction! " + f"Cost={priceCHF:.4f} CHF, user={user.id}, mandate={mandateId}, " + f"provider={provider}, model={modelName}, error={e}" + ) + + return _billingCallback def _calculateEffectiveProviders(self) -> Optional[List[str]]: """ diff --git a/modules/services/serviceBilling/mainServiceBilling.py b/modules/services/serviceBilling/mainServiceBilling.py index 472e0b58..8bf1c2c4 100644 --- a/modules/services/serviceBilling/mainServiceBilling.py +++ b/modules/services/serviceBilling/mainServiceBilling.py @@ -192,6 +192,7 @@ class BillingService: priceCHF: float, workflowId: str = None, aicoreProvider: str = None, + aicoreModel: str = None, description: str = None ) -> Optional[Dict[str, Any]]: """ @@ -206,6 +207,7 @@ class BillingService: priceCHF: Base price from AI model (before markup) workflowId: Optional workflow ID aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai') + aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o') description: Optional description Returns: @@ -222,7 +224,7 @@ class BillingService: # Build description if not description: - description = f"AI Usage: {aicoreProvider or 'unknown'}" + description = f"AI Usage: {aicoreModel or aicoreProvider or 'unknown'}" return self._billingInterface.recordUsage( mandateId=self.mandateId, @@ -232,6 +234,7 @@ class BillingService: featureInstanceId=self.featureInstanceId, featureCode=self.featureCode, aicoreProvider=aicoreProvider, + aicoreModel=aicoreModel, description=description ) @@ -399,3 +402,16 @@ class ProviderNotAllowedException(Exception): self.provider = provider self.message = message or f"Provider '{provider}' is not allowed for your role" super().__init__(self.message) + + +class BillingContextError(Exception): + """Raised when billing context is incomplete (missing mandateId, user, etc.). + + This is a FAIL-SAFE error: AI calls MUST NOT proceed without valid billing context. + Acts like a 0 CHF credit card pre-authorization check - validates that billing + CAN be recorded before any expensive AI operation starts. + """ + + def __init__(self, message: str = None): + self.message = message or "Billing context incomplete - AI call blocked" + super().__init__(self.message) diff --git a/modules/services/serviceChat/mainServiceChat.py b/modules/services/serviceChat/mainServiceChat.py index 055e34cd..b1e4a7ae 100644 --- a/modules/services/serviceChat/mainServiceChat.py +++ b/modules/services/serviceChat/mainServiceChat.py @@ -675,9 +675,11 @@ class ChatService: def storeWorkflowStat(self, workflow: Any, aiResponse: Any, process: str) -> ChatStat: """Persist workflow-level ChatStat from AiCallResponse and append to workflow stats list. - Also records the usage cost to the billing system if configured.""" + + Billing is handled at the AI call source (interfaceAiObjects._callWithModel) + via billingCallback - not here. This method only handles workflow stats. + """ try: - # Create ChatStat from AiCallResponse data statData = { "workflowId": workflow.id, "process": process, @@ -689,76 +691,16 @@ class ChatService: "errorCount": aiResponse.errorCount } - # Create the stat record in the database stat = self.interfaceDbChat.createStat(statData) - # Append to workflow stats list in memory if not hasattr(workflow, 'stats') or workflow.stats is None: workflow.stats = [] workflow.stats.append(stat) - # Record billing transaction if mandateId is available - self._recordBillingUsage(workflow, aiResponse, process) - return stat except Exception as e: logger.error(f"Failed to store workflow stat: {e}") raise - - def _recordBillingUsage(self, workflow: Any, aiResponse: Any, process: str) -> None: - """Record AI usage to the billing system. - - This method: - 1. Gets the mandate context from services - 2. Records the usage cost with 50% markup via BillingService - - Args: - workflow: ChatWorkflow object - aiResponse: AI call response with cost information - process: Process identifier for the AI call - """ - try: - # Check if we have mandate context - mandateId = getattr(self.services, 'mandateId', None) - if not mandateId: - logger.debug("No mandate context, skipping billing recording") - return - - # Check if there's a cost to record - priceCHF = getattr(aiResponse, 'priceCHF', 0.0) - if not priceCHF or priceCHF <= 0: - return - - # Get provider from AiCallResponse (set from model.connectorType) - aicoreProvider = getattr(aiResponse, 'provider', None) or 'unknown' - - # Get feature context if available - featureInstanceId = getattr(self.services, 'featureInstanceId', None) - featureCode = getattr(self.services, 'featureCode', None) - - # Import and use BillingService - from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService - - billingService = getBillingService( - self.user, - mandateId, - featureInstanceId=featureInstanceId, - featureCode=featureCode - ) - - # Record the usage (includes 50% markup) - billingService.recordUsage( - priceCHF=priceCHF, - workflowId=workflow.id, - aicoreProvider=aicoreProvider, - description=f"AI Usage: {process}" - ) - - logger.debug(f"Recorded billing usage: {priceCHF} CHF for {process} (provider: {aicoreProvider})") - - except Exception as e: - # Don't fail the main operation if billing recording fails - logger.warning(f"Failed to record billing usage (non-critical): {e}") def updateMessage(self, messageId: str, messageData: Dict[str, Any]): """Update message by delegating to the chat interface""" diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py index 172fc977..06d36dae 100644 --- a/modules/workflows/automation/mainWorkflow.py +++ b/modules/workflows/automation/mainWorkflow.py @@ -24,7 +24,7 @@ from .subAutomationUtils import parseScheduleToCron, planToPrompt, replacePlaceh logger = logging.getLogger(__name__) -async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow: +async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None, featureCode: Optional[str] = None) -> ChatWorkflow: """ Starts a new chat or continues an existing one, then launches processing asynchronously. @@ -32,12 +32,10 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode currentUser: Current user userInput: User input request workflowId: Optional workflow ID to continue existing workflow - workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution - mandateId: Mandate ID from request context (required for proper data isolation) - featureInstanceId: Feature instance ID for context - - Example usage for Dynamic mode: - workflow = await chatStart(currentUser, userInput, workflowMode=WorkflowModeEnum.WORKFLOW_DYNAMIC, mandateId=mandateId) + workflowMode: Workflow mode (Dynamic, Automation, etc.) + mandateId: Mandate ID (required for billing) + featureInstanceId: Feature instance ID (required for billing) + featureCode: Feature code (e.g., 'chatplayground', 'automation') """ try: services = getServices(currentUser, mandateId=mandateId) @@ -47,10 +45,11 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode services.allowedProviders = userInput.allowedProviders logger.info(f"AI provider filter active: {userInput.allowedProviders}") - # Store feature instance ID in services context + # Store feature context in services (for billing and RBAC) if featureInstanceId: services.featureInstanceId = featureInstanceId - services.featureCode = 'chatplayground' + if featureCode: + services.featureCode = featureCode workflowManager = WorkflowManager(services) workflow = await workflowManager.workflowStart(userInput, workflowMode, workflowId) @@ -105,8 +104,11 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: services.allowedProviders = automation.allowedProviders logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}") - # Store feature context for billing + # Context comes EXCLUSIVELY from the automation definition + services.mandateId = str(automation.mandateId) + services.featureInstanceId = str(automation.featureInstanceId) services.featureCode = 'automation' + featureInstanceId = services.featureInstanceId # 2. Replace placeholders in template to generate plan template = automation.template or "" @@ -159,11 +161,16 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: executionLog["messages"].append("Starting workflow execution") # 5. Start workflow using chatStart + # Pass mandateId, featureInstanceId, and featureCode from original services context + # so billing is recorded correctly with full feature context workflow = await chatStart( currentUser=creatorUser, userInput=userInput, workflowMode=WorkflowModeEnum.WORKFLOW_AUTOMATION, - workflowId=None + workflowId=None, + mandateId=services.mandateId, + featureInstanceId=featureInstanceId, + featureCode='automation' ) executionLog["workflowId"] = workflow.id From 1f3746aef52ae7c2fbd63201b4e2b9327c7f1f83 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Mon, 9 Feb 2026 12:49:35 +0100 Subject: [PATCH 15/18] streamlined bootstrap and initial config --- app.py | 9 + modules/aicore/aicoreModelSelector.py | 15 +- modules/aicore/aicorePluginAnthropic.py | 62 ++- modules/aicore/aicorePluginOpenai.py | 88 +-- modules/aicore/aicorePluginPerplexity.py | 22 +- modules/datamodels/datamodelAi.py | 1 + modules/datamodels/datamodelUam.py | 6 + .../automation/interfaceFeatureAutomation.py | 2 - modules/features/automation/mainAutomation.py | 24 +- .../automation/routeFeatureAutomation.py | 3 +- modules/features/chatbot/__init__.py | 8 +- .../chatbot/interfaceFeatureChatbot.py | 3 - .../features/chatbot/routeFeatureChatbot.py | 3 - .../chatplayground/mainChatplayground.py | 30 +- .../interfaceFeatureNeutralizer.py | 1 - ...inNeutralizer.py => mainNeutralization.py} | 0 ...ePlayground.py => neutralizePlayground.py} | 0 .../neutralization/routeFeatureNeutralizer.py | 2 +- .../realEstate/interfaceFeatureRealEstate.py | 5 - .../trustee/interfaceFeatureTrustee.py | 1 - modules/features/trustee/mainTrustee.py | 8 +- modules/interfaces/interfaceBootstrap.py | 512 +++++++++++++++--- modules/interfaces/interfaceDbApp.py | 180 +++--- modules/interfaces/interfaceDbBilling.py | 303 +++++++---- modules/interfaces/interfaceDbChat.py | 3 - modules/interfaces/interfaceDbManagement.py | 3 - modules/routes/routeBilling.py | 17 +- modules/routes/routeInvitations.py | 16 +- modules/security/rbac.py | 31 +- .../serviceBilling/mainServiceBilling.py | 4 +- modules/shared/eventManagement.py | 2 +- modules/system/mainSystem.py | 47 ++ modules/system/registry.py | 21 +- modules/workflows/workflowManager.py | 251 ++++----- 34 files changed, 1103 insertions(+), 580 deletions(-) rename modules/features/neutralization/{mainNeutralizer.py => mainNeutralization.py} (100%) rename modules/features/neutralization/{mainNeutralizePlayground.py => neutralizePlayground.py} (100%) diff --git a/app.py b/app.py index df7f9306..32eb31f6 100644 --- a/app.py +++ b/app.py @@ -286,6 +286,15 @@ instanceLabel = APP_CONFIG.get("APP_ENV_LABEL") async def lifespan(app: FastAPI): logger.info("Application is starting up") + # --- Register RBAC catalog for features (moved here from loadFeatureRouters for single-pass loading) --- + try: + from modules.security.rbacCatalog import getCatalogService + from modules.system.registry import registerAllFeaturesInCatalog + catalogService = getCatalogService() + registerAllFeaturesInCatalog(catalogService) + except Exception as e: + logger.warning(f"Could not register feature RBAC catalog: {e}") + # Get event user for feature lifecycle (system-level user for background operations) rootInterface = getRootInterface() eventUser = rootInterface.getUserByUsername("event") diff --git a/modules/aicore/aicoreModelSelector.py b/modules/aicore/aicoreModelSelector.py index 8bebb2d7..4724356f 100644 --- a/modules/aicore/aicoreModelSelector.py +++ b/modules/aicore/aicoreModelSelector.py @@ -73,12 +73,14 @@ class ModelSelector: contextSize = len(context.encode("utf-8")) totalSize = promptSize + contextSize # Convert bytes to approximate tokens - # Conservative estimate: 1 token ≈ 2 bytes (for safety margin) + # Balanced estimate: 1 token ≈ 3 bytes # Note: Actual tokenization varies by content type and model # - English text: ~4 bytes/token - # - Structured data/JSON: ~2-3 bytes/token + # - German/European text: ~3.5 bytes/token + # - Structured data/JSON: ~2.5-3 bytes/token # - Base64/encoded data: ~1.5-2 bytes/token - bytesPerToken = 2 # Conservative estimate for mixed content + # Using 3 as balanced estimate (previously 2 which overestimated by ~2x) + bytesPerToken = 3 # Balanced estimate for mixed content promptTokens = promptSize / bytesPerToken contextTokens = contextSize / bytesPerToken totalTokens = totalSize / bytesPerToken @@ -98,9 +100,16 @@ class ModelSelector: logger.debug(f"Models with {options.operationType.value}: {[m.name for m in operationFiltered]}") # Step 2: Filter by prompt size (MUST be <= 80% of context size) + # AND by maxInputTokensPerRequest (provider rate limit / TPM) # Note: contextLength is in tokens, so we need to compare tokens with tokens promptFiltered = [] for model in operationFiltered: + # Check provider rate limit first (maxInputTokensPerRequest) + maxRequestTokens = getattr(model, 'maxInputTokensPerRequest', None) + if maxRequestTokens and maxRequestTokens > 0 and totalTokens > maxRequestTokens: + logger.debug(f"Model {model.name} filtered out: totalTokens={totalTokens:.0f} > maxInputTokensPerRequest={maxRequestTokens} (provider rate limit)") + continue + if model.contextLength == 0: # No context length limit - always pass promptFiltered.append(model) diff --git a/modules/aicore/aicorePluginAnthropic.py b/modules/aicore/aicorePluginAnthropic.py index eeea9a07..5809a203 100644 --- a/modules/aicore/aicorePluginAnthropic.py +++ b/modules/aicore/aicorePluginAnthropic.py @@ -46,7 +46,6 @@ class AiAnthropic(BaseConnectorAi): return "anthropic" def getModels(self) -> List[AiModel]: - # return [] # TODO: DEBUG TO TURN ON AFTER TESTING # Get all available Anthropic models. return [ AiModel( @@ -57,11 +56,10 @@ class AiAnthropic(BaseConnectorAi): temperature=0.2, maxTokens=8192, contextLength=200000, - costPer1kTokensInput=0.015, - costPer1kTokensOutput=0.075, + costPer1kTokensInput=0.003, # $3/M tokens (updated 2026-02) + costPer1kTokensOutput=0.015, # $15/M tokens (updated 2026-02) speedRating=6, # Slower due to high-quality processing qualityRating=10, # Best quality available - # capabilities removed (not used in business logic) functionCall=self.callAiBasic, priority=PriorityEnum.QUALITY, processingMode=ProcessingModeEnum.DETAILED, @@ -72,7 +70,55 @@ class AiAnthropic(BaseConnectorAi): (OperationTypeEnum.DATA_EXTRACT, 8) ), version="claude-sonnet-4-5-20250929", - calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.003 + (bytesReceived / 4 / 1000) * 0.015 + ), + AiModel( + name="claude-haiku-4-5-20251001", + displayName="Anthropic Claude Haiku 4.5", + connectorType="anthropic", + apiUrl="https://api.anthropic.com/v1/messages", + temperature=0.2, + maxTokens=8192, + contextLength=200000, + costPer1kTokensInput=0.001, # $1/M tokens (updated 2026-02) + costPer1kTokensOutput=0.005, # $5/M tokens (updated 2026-02) + speedRating=9, # Very fast, lightweight model + qualityRating=8, # Good quality, cost-efficient + functionCall=self.callAiBasic, + priority=PriorityEnum.SPEED, + processingMode=ProcessingModeEnum.BASIC, + operationTypes=createOperationTypeRatings( + (OperationTypeEnum.PLAN, 8), + (OperationTypeEnum.DATA_ANALYSE, 8), + (OperationTypeEnum.DATA_GENERATE, 8), + (OperationTypeEnum.DATA_EXTRACT, 7) + ), + version="claude-haiku-4-5-20251001", + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.001 + (bytesReceived / 4 / 1000) * 0.005 + ), + AiModel( + name="claude-opus-4-6", + displayName="Anthropic Claude Opus 4.6", + connectorType="anthropic", + apiUrl="https://api.anthropic.com/v1/messages", + temperature=0.2, + maxTokens=8192, + contextLength=200000, + costPer1kTokensInput=0.005, # $5/M tokens (updated 2026-02) + costPer1kTokensOutput=0.025, # $25/M tokens (updated 2026-02) + speedRating=5, # Moderate latency, most capable + qualityRating=10, # Top-tier intelligence + functionCall=self.callAiBasic, + priority=PriorityEnum.QUALITY, + processingMode=ProcessingModeEnum.DETAILED, + operationTypes=createOperationTypeRatings( + (OperationTypeEnum.PLAN, 10), + (OperationTypeEnum.DATA_ANALYSE, 10), + (OperationTypeEnum.DATA_GENERATE, 10), + (OperationTypeEnum.DATA_EXTRACT, 9) + ), + version="claude-opus-4-6", + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.025 ), AiModel( name="claude-sonnet-4-5-20250929", @@ -82,8 +128,8 @@ class AiAnthropic(BaseConnectorAi): temperature=0.2, maxTokens=8192, contextLength=200000, - costPer1kTokensInput=0.015, - costPer1kTokensOutput=0.075, + costPer1kTokensInput=0.003, # $3/M tokens (updated 2026-02) + costPer1kTokensOutput=0.015, # $15/M tokens (updated 2026-02) speedRating=6, qualityRating=10, functionCall=self.callAiImage, @@ -93,7 +139,7 @@ class AiAnthropic(BaseConnectorAi): (OperationTypeEnum.IMAGE_ANALYSE, 10) ), version="claude-sonnet-4-5-20250929", - calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.003 + (bytesReceived / 4 / 1000) * 0.015 ) ] diff --git a/modules/aicore/aicorePluginOpenai.py b/modules/aicore/aicorePluginOpenai.py index 931ece10..5465858c 100644 --- a/modules/aicore/aicorePluginOpenai.py +++ b/modules/aicore/aicorePluginOpenai.py @@ -6,7 +6,7 @@ from typing import List from fastapi import HTTPException from modules.shared.configuration import APP_CONFIG from .aicoreBase import BaseConnectorAi -from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings +from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings, AiCallPromptImage # Configure logger logger = logging.getLogger(__name__) @@ -15,6 +15,10 @@ class ContextLengthExceededException(Exception): """Exception raised when the context length exceeds the model's limit""" pass +class RateLimitExceededException(Exception): + """Exception raised when the provider's rate limit (TPM) is exceeded""" + pass + def loadConfigData(): """Load configuration data for OpenAI connector""" return { @@ -57,11 +61,11 @@ class AiOpenai(BaseConnectorAi): temperature=0.2, maxTokens=16384, contextLength=128000, - costPer1kTokensInput=0.03, - costPer1kTokensOutput=0.06, + maxInputTokensPerRequest=25000, # OpenAI org TPM limit is 30K, keep 5K buffer + costPer1kTokensInput=0.0025, # $2.50/M tokens (updated 2026-02) + costPer1kTokensOutput=0.01, # $10.00/M tokens (updated 2026-02) speedRating=8, # Good speed for complex tasks qualityRating=10, # High quality - # capabilities removed (not used in business logic) functionCall=self.callAiBasic, priority=PriorityEnum.BALANCED, processingMode=ProcessingModeEnum.ADVANCED, @@ -72,43 +76,44 @@ class AiOpenai(BaseConnectorAi): (OperationTypeEnum.DATA_EXTRACT, 7) ), version="gpt-4o", - calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01 ), AiModel( - name="gpt-3.5-turbo", - displayName="OpenAI GPT-3.5 Turbo", - connectorType="openai", - apiUrl="https://api.openai.com/v1/chat/completions", - temperature=0.2, - maxTokens=4096, - contextLength=16000, - costPer1kTokensInput=0.0015, - costPer1kTokensOutput=0.002, - speedRating=9, # Very fast - qualityRating=7, # Good but not premium - # capabilities removed (not used in business logic) - functionCall=self.callAiBasic, - priority=PriorityEnum.SPEED, - processingMode=ProcessingModeEnum.BASIC, - operationTypes=createOperationTypeRatings( - (OperationTypeEnum.PLAN, 7), - (OperationTypeEnum.DATA_ANALYSE, 8), - (OperationTypeEnum.DATA_GENERATE, 8) - # Note: GPT-3.5-turbo does NOT support vision/image operations - ), - version="gpt-3.5-turbo", - calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0015 + (bytesReceived / 4 / 1000) * 0.002 - ), - AiModel( - name="gpt-4o", - displayName="OpenAI GPT-4o Instance Vision", + name="gpt-4o-mini", + displayName="OpenAI GPT-4o Mini", connectorType="openai", apiUrl="https://api.openai.com/v1/chat/completions", temperature=0.2, maxTokens=16384, contextLength=128000, - costPer1kTokensInput=0.03, - costPer1kTokensOutput=0.06, + maxInputTokensPerRequest=25000, # OpenAI org TPM limit, keep buffer + costPer1kTokensInput=0.00015, # $0.15/M tokens (updated 2026-02) + costPer1kTokensOutput=0.0006, # $0.60/M tokens (updated 2026-02) + speedRating=9, # Very fast + qualityRating=8, # Good quality, replaces gpt-3.5-turbo + functionCall=self.callAiBasic, + priority=PriorityEnum.SPEED, + processingMode=ProcessingModeEnum.BASIC, + operationTypes=createOperationTypeRatings( + (OperationTypeEnum.PLAN, 8), + (OperationTypeEnum.DATA_ANALYSE, 8), + (OperationTypeEnum.DATA_GENERATE, 9), + (OperationTypeEnum.DATA_EXTRACT, 7) + ), + version="gpt-4o-mini", + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00015 + (bytesReceived / 4 / 1000) * 0.0006 + ), + AiModel( + name="gpt-4o", + displayName="OpenAI GPT-4o Vision", + connectorType="openai", + apiUrl="https://api.openai.com/v1/chat/completions", + temperature=0.2, + maxTokens=16384, + contextLength=128000, + maxInputTokensPerRequest=25000, # OpenAI org TPM limit is 30K, keep 5K buffer + costPer1kTokensInput=0.0025, # $2.50/M tokens (updated 2026-02) + costPer1kTokensOutput=0.01, # $10.00/M tokens (updated 2026-02) speedRating=6, # Slower for vision tasks qualityRating=9, # High quality vision functionCall=self.callAiImage, @@ -118,7 +123,7 @@ class AiOpenai(BaseConnectorAi): (OperationTypeEnum.IMAGE_ANALYSE, 9) ), version="gpt-4o", - calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0025 + (bytesReceived / 4 / 1000) * 0.01 ), AiModel( name="dall-e-3", @@ -183,6 +188,19 @@ class AiOpenai(BaseConnectorAi): error_message = f"OpenAI API error: {response.status_code} - {response.text}" logger.error(error_message) + # Check for rate limit exceeded (429 TPM) + if response.status_code == 429: + try: + error_data = response.json() + error_msg = error_data.get("error", {}).get("message", "Rate limit exceeded") + raise RateLimitExceededException( + f"Rate limit exceeded for {model.name}: {error_msg}" + ) + except (ValueError, KeyError): + raise RateLimitExceededException( + f"Rate limit exceeded for {model.name}" + ) + # Check for context length exceeded error if response.status_code == 400: try: diff --git a/modules/aicore/aicorePluginPerplexity.py b/modules/aicore/aicorePluginPerplexity.py index 7cb5e928..dd13deb1 100644 --- a/modules/aicore/aicorePluginPerplexity.py +++ b/modules/aicore/aicorePluginPerplexity.py @@ -59,13 +59,12 @@ class AiPerplexity(BaseConnectorAi): connectorType="perplexity", apiUrl="https://api.perplexity.ai/chat/completions", temperature=0.2, - maxTokens=24000, # Increased for detailed web crawl responses (Perplexity supports up to 25k) - contextLength=32000, - costPer1kTokensInput=0.005, - costPer1kTokensOutput=0.005, + maxTokens=24000, + contextLength=127000, # 127K context window (updated 2026-02) + costPer1kTokensInput=0.001, # $1/M tokens (updated 2026-02) + costPer1kTokensOutput=0.001, # $1/M tokens (updated 2026-02) speedRating=8, qualityRating=8, - # capabilities removed (not used in business logic) functionCall=self._routeWebOperation, priority=PriorityEnum.BALANCED, processingMode=ProcessingModeEnum.ADVANCED, @@ -74,7 +73,7 @@ class AiPerplexity(BaseConnectorAi): (OperationTypeEnum.WEB_CRAWL, 7) ), version="sonar", - calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.005 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.001 + (bytesReceived / 4 / 1000) * 0.001 ), AiModel( name="sonar-pro", @@ -82,13 +81,12 @@ class AiPerplexity(BaseConnectorAi): connectorType="perplexity", apiUrl="https://api.perplexity.ai/chat/completions", temperature=0.2, - maxTokens=24000, # Increased for detailed web crawl responses (Perplexity supports up to 25k) - contextLength=32000, - costPer1kTokensInput=0.01, - costPer1kTokensOutput=0.01, + maxTokens=24000, + contextLength=200000, # 200K context window (updated 2026-02) + costPer1kTokensInput=0.003, # $3/M tokens (updated 2026-02) + costPer1kTokensOutput=0.015, # $15/M tokens (updated 2026-02) speedRating=6, # Slower due to AI analysis qualityRating=9, # Best AI analysis quality - # capabilities removed (not used in business logic) functionCall=self._routeWebOperation, priority=PriorityEnum.QUALITY, processingMode=ProcessingModeEnum.DETAILED, @@ -97,7 +95,7 @@ class AiPerplexity(BaseConnectorAi): (OperationTypeEnum.WEB_CRAWL, 8) ), version="sonar-pro", - calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.01 + (bytesReceived / 4 / 1000) * 0.01 + calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.003 + (bytesReceived / 4 / 1000) * 0.015 ) ] diff --git a/modules/datamodels/datamodelAi.py b/modules/datamodels/datamodelAi.py index 44eac445..5b259a02 100644 --- a/modules/datamodels/datamodelAi.py +++ b/modules/datamodels/datamodelAi.py @@ -87,6 +87,7 @@ class AiModel(BaseModel): # Token and context limits maxTokens: int = Field(description="Maximum tokens this model can generate") contextLength: int = Field(description="Maximum context length this model can handle") + maxInputTokensPerRequest: Optional[int] = Field(default=None, description="Max input tokens per single request (provider rate limit / TPM). If set, model selector filters requests exceeding this limit.") # Cost information costPer1kTokensInput: float = Field(default=0.0, description="Cost per 1000 input tokens") diff --git a/modules/datamodels/datamodelUam.py b/modules/datamodels/datamodelUam.py index 155047a2..d8e7906a 100644 --- a/modules/datamodels/datamodelUam.py +++ b/modules/datamodels/datamodelUam.py @@ -83,6 +83,11 @@ class Mandate(BaseModel): description="Indicates whether the mandate is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False} ) + isSystem: bool = Field( + default=False, + description="Whether this is a system mandate (e.g. root mandate). Cannot be deleted.", + json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": True, "frontend_required": False} + ) registerModelLabels( @@ -93,6 +98,7 @@ registerModelLabels( "name": {"en": "Name", "de": "Name", "fr": "Nom"}, "description": {"en": "Description", "de": "Beschreibung", "fr": "Description"}, "enabled": {"en": "Enabled", "de": "Aktiviert", "fr": "Activé"}, + "isSystem": {"en": "System Mandate", "de": "System-Mandant", "fr": "Mandat système"}, }, ) diff --git a/modules/features/automation/interfaceFeatureAutomation.py b/modules/features/automation/interfaceFeatureAutomation.py index 770d5eb0..e99f7683 100644 --- a/modules/features/automation/interfaceFeatureAutomation.py +++ b/modules/features/automation/interfaceFeatureAutomation.py @@ -69,8 +69,6 @@ class AutomationObjects: userId=self.userId, ) - # Initialize database system - self.db.initDbSystem() logger.debug(f"Automation database initialized for user {self.userId}") def setUserContext(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None): diff --git a/modules/features/automation/mainAutomation.py b/modules/features/automation/mainAutomation.py index 924f5bc9..bbf258cc 100644 --- a/modules/features/automation/mainAutomation.py +++ b/modules/features/automation/mainAutomation.py @@ -59,24 +59,7 @@ RESOURCE_OBJECTS = [ ] # Template roles for this feature -# IMPORTANT: "viewer" role is required for automatic user assignment! TEMPLATE_ROLES = [ - { - "roleLabel": "viewer", - "description": { - "en": "Automation Viewer - View automations and execution results", - "de": "Automatisierungs-Betrachter - Automatisierungen und Ausführungsergebnisse einsehen", - "fr": "Visualiseur automatisation - Consulter les automatisations et résultats" - }, - "accessRules": [ - # UI access to all views - {"context": "UI", "item": "ui.feature.automation.definitions", "view": True}, - {"context": "UI", "item": "ui.feature.automation.templates", "view": True}, - {"context": "UI", "item": "ui.feature.automation.logs", "view": True}, - # Read-only DATA access - {"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "n"}, - ] - }, { "roleLabel": "automation-admin", "description": { @@ -115,7 +98,7 @@ TEMPLATE_ROLES = [ "fr": "Visualiseur automatisation - Consulter les automatisations et résultats" }, "accessRules": [ - # UI access to view only - vollqualifizierte ObjectKeys + # UI access to view only {"context": "UI", "item": "ui.feature.automation.definitions", "view": True}, {"context": "UI", "item": "ui.feature.automation.logs", "view": True}, # Read-only DATA access (my level) @@ -130,7 +113,8 @@ def getFeatureDefinition() -> Dict[str, Any]: return { "code": FEATURE_CODE, "label": FEATURE_LABEL, - "icon": FEATURE_ICON + "icon": FEATURE_ICON, + "autoCreateInstance": True, # Automatically create instance in root mandate during bootstrap } @@ -215,8 +199,6 @@ def _syncTemplateRolesToDb() -> int: if roleLabel in existingRoleLabels: roleId = existingRoleLabels[roleLabel] - logger.debug(f"Template role '{roleLabel}' already exists with ID {roleId}") - # Ensure AccessRules exist for this role _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) else: diff --git a/modules/features/automation/routeFeatureAutomation.py b/modules/features/automation/routeFeatureAutomation.py index 8a5fee1d..f7c5feda 100644 --- a/modules/features/automation/routeFeatureAutomation.py +++ b/modules/features/automation/routeFeatureAutomation.py @@ -19,8 +19,6 @@ from modules.features.automation.datamodelFeatureAutomation import AutomationDef from modules.datamodels.datamodelChat import ChatWorkflow from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict from modules.shared.attributeUtils import getModelAttributeDefinitions -from modules.workflows.automation import executeAutomation - # Configure logger logger = logging.getLogger(__name__) @@ -371,6 +369,7 @@ async def execute_automation_route( if context.featureInstanceId: services.featureInstanceId = str(context.featureInstanceId) services.featureCode = 'automation' + from modules.workflows.automation import executeAutomation workflow = await executeAutomation(automationId, services) return workflow except HTTPException: diff --git a/modules/features/chatbot/__init__.py b/modules/features/chatbot/__init__.py index 46017d53..30b57e1f 100644 --- a/modules/features/chatbot/__init__.py +++ b/modules/features/chatbot/__init__.py @@ -2,8 +2,14 @@ # All rights reserved. """ Chatbot feature - LangGraph-based chatbot implementation. +Lazy-loaded to avoid importing langgraph/langchain at boot time. """ -from .service import chatProcess + +async def chatProcess(*args, **kwargs): + """Lazy wrapper - imports the real chatProcess on first call to defer langgraph loading.""" + from .service import chatProcess as _chatProcess + return await _chatProcess(*args, **kwargs) + __all__ = ['chatProcess'] diff --git a/modules/features/chatbot/interfaceFeatureChatbot.py b/modules/features/chatbot/interfaceFeatureChatbot.py index 68474898..4d77f633 100644 --- a/modules/features/chatbot/interfaceFeatureChatbot.py +++ b/modules/features/chatbot/interfaceFeatureChatbot.py @@ -329,9 +329,6 @@ class ChatObjects: userId=self.userId ) - # Initialize database system - self.db.initDbSystem() - logger.info("Database initialized successfully") except Exception as e: logger.error(f"Failed to initialize database: {str(e)}") diff --git a/modules/features/chatbot/routeFeatureChatbot.py b/modules/features/chatbot/routeFeatureChatbot.py index d8b4dd70..ade51e9f 100644 --- a/modules/features/chatbot/routeFeatureChatbot.py +++ b/modules/features/chatbot/routeFeatureChatbot.py @@ -32,9 +32,6 @@ from modules.datamodels.datamodelPagination import PaginationParams, PaginatedRe from modules.features.chatbot import chatProcess from modules.features.chatbot.streaming.events import get_event_manager -# Import workflow control functions -from modules.workflows.automation import chatStop - # Configure logger logger = logging.getLogger(__name__) diff --git a/modules/features/chatplayground/mainChatplayground.py b/modules/features/chatplayground/mainChatplayground.py index 268ee467..085d93e4 100644 --- a/modules/features/chatplayground/mainChatplayground.py +++ b/modules/features/chatplayground/mainChatplayground.py @@ -54,15 +54,30 @@ TEMPLATE_ROLES = [ { "roleLabel": "viewer", "description": { - "en": "Chat Playground Viewer - View and use chat playground", - "de": "Chat Playground Betrachter - Chat Playground ansehen und nutzen", - "fr": "Visualiseur Chat Playground - Consulter et utiliser le chat playground" + "en": "Chat Playground Viewer - View chat playground (read-only)", + "de": "Chat Playground Betrachter - Chat Playground ansehen (nur lesen)", + "fr": "Visualiseur Chat Playground - Consulter le chat playground (lecture seule)" }, "accessRules": [ - # UI access to all views + # UI: only playground view, NO workflows + {"context": "UI", "item": "ui.feature.chatplayground.playground", "view": True}, + # RESOURCE: NO access (viewer cannot start/stop/access chat data) + # DATA access (own records, read-only) + {"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"}, + ] + }, + { + "roleLabel": "user", + "description": { + "en": "Chat Playground User - Use chat playground and workflows", + "de": "Chat Playground Benutzer - Chat Playground und Workflows nutzen", + "fr": "Utilisateur Chat Playground - Utiliser le chat playground et les workflows" + }, + "accessRules": [ + # UI: full access to all views {"context": "UI", "item": "ui.feature.chatplayground.playground", "view": True}, {"context": "UI", "item": "ui.feature.chatplayground.workflows", "view": True}, - # Resource access + # Resource access: can start/stop workflows and access chat data {"context": "RESOURCE", "item": "resource.feature.chatplayground.start", "view": True}, {"context": "RESOURCE", "item": "resource.feature.chatplayground.stop", "view": True}, {"context": "RESOURCE", "item": "resource.feature.chatplayground.chatData", "view": True}, @@ -94,7 +109,8 @@ def getFeatureDefinition() -> Dict[str, Any]: return { "code": FEATURE_CODE, "label": FEATURE_LABEL, - "icon": FEATURE_ICON + "icon": FEATURE_ICON, + "autoCreateInstance": True, # Automatically create instance in root mandate during bootstrap } @@ -179,8 +195,6 @@ def _syncTemplateRolesToDb() -> int: if roleLabel in existingRoleLabels: roleId = existingRoleLabels[roleLabel] - logger.debug(f"Template role '{roleLabel}' already exists with ID {roleId}") - # Ensure AccessRules exist for this role _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) else: diff --git a/modules/features/neutralization/interfaceFeatureNeutralizer.py b/modules/features/neutralization/interfaceFeatureNeutralizer.py index 98b85fdb..32d9493b 100644 --- a/modules/features/neutralization/interfaceFeatureNeutralizer.py +++ b/modules/features/neutralization/interfaceFeatureNeutralizer.py @@ -66,7 +66,6 @@ class InterfaceFeatureNeutralizer: dbPort=dbPort, userId=self.userId, ) - self.db.initDbSystem() logger.debug("Neutralizer database initialized successfully") except Exception as e: logger.error(f"Error initializing Neutralizer database: {str(e)}") diff --git a/modules/features/neutralization/mainNeutralizer.py b/modules/features/neutralization/mainNeutralization.py similarity index 100% rename from modules/features/neutralization/mainNeutralizer.py rename to modules/features/neutralization/mainNeutralization.py diff --git a/modules/features/neutralization/mainNeutralizePlayground.py b/modules/features/neutralization/neutralizePlayground.py similarity index 100% rename from modules/features/neutralization/mainNeutralizePlayground.py rename to modules/features/neutralization/neutralizePlayground.py diff --git a/modules/features/neutralization/routeFeatureNeutralizer.py b/modules/features/neutralization/routeFeatureNeutralizer.py index 33b9a00d..d1590d28 100644 --- a/modules/features/neutralization/routeFeatureNeutralizer.py +++ b/modules/features/neutralization/routeFeatureNeutralizer.py @@ -9,7 +9,7 @@ from modules.auth import limiter, getRequestContext, RequestContext # Import interfaces from .datamodelFeatureNeutralizer import DataNeutraliserConfig, DataNeutralizerAttributes -from .mainNeutralizePlayground import NeutralizationPlayground +from .neutralizePlayground import NeutralizationPlayground # Configure logger logger = logging.getLogger(__name__) diff --git a/modules/features/realEstate/interfaceFeatureRealEstate.py b/modules/features/realEstate/interfaceFeatureRealEstate.py index 40a85c7e..86374a2c 100644 --- a/modules/features/realEstate/interfaceFeatureRealEstate.py +++ b/modules/features/realEstate/interfaceFeatureRealEstate.py @@ -85,11 +85,6 @@ class RealEstateObjects: userId=self.userId if self.userId else None, ) - # Initialize database system (creates database and system table if needed) - # Note: This is also called in DatabaseConnector.__init__, but we call it explicitly - # for consistency with other interfaces and to ensure proper initialization - self.db.initDbSystem() - # Ensure all supporting tables are created (Land, Kanton, Gemeinde, Dokument) # These tables are needed for foreign key relationships self._ensureSupportingTablesExist() diff --git a/modules/features/trustee/interfaceFeatureTrustee.py b/modules/features/trustee/interfaceFeatureTrustee.py index bb6695d3..8710d148 100644 --- a/modules/features/trustee/interfaceFeatureTrustee.py +++ b/modules/features/trustee/interfaceFeatureTrustee.py @@ -155,7 +155,6 @@ class TrusteeObjects: userId=self.userId, ) - self.db.initDbSystem() logger.info(f"Trustee database initialized successfully for user {self.userId}") except Exception as e: logger.error(f"Failed to initialize Trustee database: {str(e)}") diff --git a/modules/features/trustee/mainTrustee.py b/modules/features/trustee/mainTrustee.py index b917f9ad..8d7e1243 100644 --- a/modules/features/trustee/mainTrustee.py +++ b/modules/features/trustee/mainTrustee.py @@ -144,12 +144,11 @@ TEMPLATE_ROLES = [ "fr": "Comptable fiduciaire - Gérer les données comptables et financières" }, "accessRules": [ - # UI access to main views (not admin views) - vollqualifizierte ObjectKeys + # UI access to main views (not admin views, not expense-import) - vollqualifizierte ObjectKeys {"context": "UI", "item": "ui.feature.trustee.dashboard", "view": True}, {"context": "UI", "item": "ui.feature.trustee.positions", "view": True}, {"context": "UI", "item": "ui.feature.trustee.documents", "view": True}, {"context": "UI", "item": "ui.feature.trustee.position-documents", "view": True}, - {"context": "UI", "item": "ui.feature.trustee.expense-import", "view": True}, # Group-level DATA access {"context": "DATA", "item": None, "view": True, "read": "g", "create": "g", "update": "g", "delete": "g"}, ] @@ -162,11 +161,12 @@ TEMPLATE_ROLES = [ "fr": "Client fiduciaire - Consulter ses propres données comptables et documents" }, "accessRules": [ - # UI access to main views only (read-only focus) - vollqualifizierte ObjectKeys + # UI access to main views + expense-import - vollqualifizierte ObjectKeys {"context": "UI", "item": "ui.feature.trustee.dashboard", "view": True}, {"context": "UI", "item": "ui.feature.trustee.positions", "view": True}, {"context": "UI", "item": "ui.feature.trustee.documents", "view": True}, {"context": "UI", "item": "ui.feature.trustee.position-documents", "view": True}, + {"context": "UI", "item": "ui.feature.trustee.expense-import", "view": True}, # Own records only (MY level) - explizite Regeln pro Tabelle {"context": "DATA", "item": "data.feature.trustee.TrusteePosition", "view": True, "read": "m", "create": "m", "update": "m", "delete": "n"}, {"context": "DATA", "item": "data.feature.trustee.TrusteeDocument", "view": True, "read": "m", "create": "m", "update": "m", "delete": "n"}, @@ -279,8 +279,6 @@ def _syncTemplateRolesToDb() -> int: if roleLabel in existingRoleLabels: roleId = existingRoleLabels[roleLabel] - logger.debug(f"Template role '{roleLabel}' already exists with ID {roleId}") - # Ensure AccessRules exist for this role _ensureAccessRulesForRole(rootInterface, roleId, roleTemplate.get("accessRules", [])) else: diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py index 99cca1a2..f750565d 100644 --- a/modules/interfaces/interfaceBootstrap.py +++ b/modules/interfaces/interfaceBootstrap.py @@ -51,12 +51,16 @@ def initBootstrap(db: DatabaseConnector) -> None: # Initialize root mandate mandateId = initRootMandate(db) - # Initialize roles FIRST (needed for AccessRules) + # Initialize system role TEMPLATES (mandateId=None, isSystemRole=True) initRoles(db) - # Initialize RBAC rules (uses roleIds from roles) + # Initialize RBAC rules for template roles initRbacRules(db) + # Copy system template roles to ALL mandates as mandate-instance roles + # This also serves as migration for existing mandates that don't have instance roles yet + _ensureAllMandatesHaveSystemRoles(db) + # Initialize admin user adminUserId = initAdminUser(db, mandateId) @@ -64,6 +68,7 @@ def initBootstrap(db: DatabaseConnector) -> None: eventUserId = initEventUser(db, mandateId) # Assign initial user memberships (via UserMandate + UserMandateRole) + # Uses mandate-instance roles (not template roles) if adminUserId and eventUserId and mandateId: assignInitialUserMemberships(db, mandateId, adminUserId, eventUserId) @@ -163,8 +168,8 @@ def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str] def initRootMandateFeatures(db: DatabaseConnector, mandateId: str) -> None: """ - Create feature instances for root mandate (chatplayground, automation). - These features are available to all users by default. + Create feature instances for root mandate. + Dynamically discovers all feature modules with autoCreateInstance=True. Args: db: Database connector instance @@ -172,14 +177,29 @@ def initRootMandateFeatures(db: DatabaseConnector, mandateId: str) -> None: """ from modules.datamodels.datamodelFeatures import FeatureInstance from modules.interfaces.interfaceFeatures import getFeatureInterface + from modules.system.registry import loadFeatureMainModules logger.info("Initializing root mandate features") - # Features to create instances for - featuresToCreate = [ - {"code": "chatplayground", "label": "Chat Playground"}, - {"code": "automation", "label": "Automation"}, - ] + # Dynamically discover features with autoCreateInstance=True + featuresToCreate = [] + mainModules = loadFeatureMainModules() + + for featureName, module in mainModules.items(): + if hasattr(module, "getFeatureDefinition"): + try: + featureDef = module.getFeatureDefinition() + if featureDef.get("autoCreateInstance", False): + featureCode = featureDef.get("code", featureName) + featureLabel = featureDef.get("label", {}).get("en", featureName) + featuresToCreate.append({"code": featureCode, "label": featureLabel}) + logger.debug(f"Feature '{featureCode}' marked for auto-creation in root mandate") + except Exception as e: + logger.warning(f"Could not read feature definition for '{featureName}': {e}") + + if not featuresToCreate: + logger.info("No features marked for auto-creation in root mandate") + return featureInterface = getFeatureInterface(db) @@ -225,6 +245,7 @@ def initRootMandateFeatures(db: DatabaseConnector, mandateId: str) -> None: def initRootMandate(db: DatabaseConnector) -> Optional[str]: """ Creates the Root mandate if it doesn't exist. + Root mandate is identified by name='root' AND isSystem=True. Args: db: Database connector instance @@ -232,14 +253,23 @@ def initRootMandate(db: DatabaseConnector) -> Optional[str]: Returns: Mandate ID if created or found, None otherwise """ - existingMandates = db.getRecordset(Mandate) + # Find existing root mandate by name AND isSystem flag + existingMandates = db.getRecordset(Mandate, recordFilter={"name": "root", "isSystem": True}) if existingMandates: mandateId = existingMandates[0].get("id") logger.info(f"Root mandate already exists with ID {mandateId}") return mandateId + # Check for legacy root mandates (name="Root" without isSystem flag) and migrate + legacyMandates = db.getRecordset(Mandate, recordFilter={"name": "Root"}) + if legacyMandates: + mandateId = legacyMandates[0].get("id") + logger.info(f"Migrating legacy Root mandate {mandateId}: setting name='root', isSystem=True") + db.recordModify(Mandate, mandateId, {"name": "root", "isSystem": True}) + return mandateId + logger.info("Creating Root mandate") - rootMandate = Mandate(name="Root", enabled=True) + rootMandate = Mandate(name="root", isSystem=True, enabled=True) createdMandate = db.recordCreate(Mandate, rootMandate) mandateId = createdMandate.get("id") logger.info(f"Root mandate created with ID {mandateId}") @@ -383,11 +413,113 @@ def initRoles(db: DatabaseConnector) -> None: logger.warning(f"Error creating role {role.roleLabel}: {e}") else: _roleIdCache[role.roleLabel] = existingRoleLabels[role.roleLabel] - logger.debug(f"Role {role.roleLabel} already exists with ID {existingRoleLabels[role.roleLabel]}") logger.info("Roles initialization completed") +def _ensureAllMandatesHaveSystemRoles(db: DatabaseConnector) -> None: + """ + Ensure all existing mandates have system-instance roles. + Serves as both initial setup and migration for existing mandates. + """ + allMandates = db.getRecordset(Mandate) + if not allMandates: + return + + for mandate in allMandates: + mandateId = mandate.get("id") + copySystemRolesToMandate(db, mandateId) + + +def copySystemRolesToMandate(db: DatabaseConnector, mandateId: str) -> int: + """ + Copy system template roles (mandateId=None, isSystemRole=True) to a mandate + as mandate-instance roles. Also copies all AccessRules for each role. + + This is analogous to how feature template roles are copied to feature instances. + Each mandate gets its own instances of admin/user/viewer with their AccessRules. + + Args: + db: Database connector instance + mandateId: Target mandate ID + + Returns: + Number of roles copied + """ + import uuid as _uuid + + # Find system template roles (global, no mandateId) + templateRoles = db.getRecordset( + Role, + recordFilter={"isSystemRole": True, "mandateId": None} + ) + + if not templateRoles: + logger.debug("No system template roles found to copy") + return 0 + + # Check which roles already exist for this mandate + existingMandateRoles = db.getRecordset( + Role, + recordFilter={"mandateId": mandateId, "featureInstanceId": None} + ) + existingLabels = {r.get("roleLabel") for r in existingMandateRoles} + + # Load all AccessRules for template roles + templateRoleIds = [r.get("id") for r in templateRoles] + rulesByRoleId = {} + for roleId in templateRoleIds: + rules = db.getRecordset(AccessRule, recordFilter={"roleId": roleId}) + rulesByRoleId[roleId] = rules + + copiedCount = 0 + for templateRole in templateRoles: + roleLabel = templateRole.get("roleLabel") + + # Skip if mandate already has this role + if roleLabel in existingLabels: + logger.debug(f"Mandate {mandateId} already has role '{roleLabel}', skipping") + continue + + newRoleId = str(_uuid.uuid4()) + + # Create mandate-instance role + newRole = Role( + id=newRoleId, + roleLabel=roleLabel, + description=templateRole.get("description", {}), + mandateId=mandateId, + featureInstanceId=None, + featureCode=None, + isSystemRole=True # Still a system role, but bound to this mandate + ) + db.recordCreate(Role, newRole.model_dump()) + + # Copy AccessRules + templateRules = rulesByRoleId.get(templateRole.get("id"), []) + for rule in templateRules: + newRule = AccessRule( + id=str(_uuid.uuid4()), + roleId=newRoleId, + context=rule.get("context"), + item=rule.get("item"), + view=rule.get("view", False), + read=rule.get("read"), + create=rule.get("create"), + update=rule.get("update"), + delete=rule.get("delete") + ) + db.recordCreate(AccessRule, newRule.model_dump()) + + copiedCount += 1 + logger.info(f"Copied system role '{roleLabel}' to mandate {mandateId} with {len(templateRules)} AccessRules") + + if copiedCount > 0: + logger.info(f"Copied {copiedCount} system roles to mandate {mandateId}") + + return copiedCount + + def _getRoleId(db: DatabaseConnector, roleLabel: str) -> Optional[str]: """ Get role ID by label, using cache or database lookup. @@ -861,6 +993,117 @@ def _createTableSpecificRules(db: DatabaseConnector) -> None: delete=AccessLevel.NONE, )) + # ------------------------------------------------------------------------- + # Billing Namespace - Billing accounts and transactions + # ------------------------------------------------------------------------- + + # BillingAccount: User sees own accounts (MY), Admin sees all in mandate (GROUP) + # Each user must see all billing accounts assigned to them + if adminId: + tableRules.append(AccessRule( + roleId=adminId, + context=AccessRuleContext.DATA, + item="data.billing.BillingAccount", + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + if userId: + tableRules.append(AccessRule( + roleId=userId, + context=AccessRuleContext.DATA, + item="data.billing.BillingAccount", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + if viewerId: + tableRules.append(AccessRule( + roleId=viewerId, + context=AccessRuleContext.DATA, + item="data.billing.BillingAccount", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # BillingTransaction: User sees own transactions (MY), Admin sees all in mandate (GROUP) + if adminId: + tableRules.append(AccessRule( + roleId=adminId, + context=AccessRuleContext.DATA, + item="data.billing.BillingTransaction", + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + if userId: + tableRules.append(AccessRule( + roleId=userId, + context=AccessRuleContext.DATA, + item="data.billing.BillingTransaction", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + if viewerId: + tableRules.append(AccessRule( + roleId=viewerId, + context=AccessRuleContext.DATA, + item="data.billing.BillingTransaction", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # BillingSettings: Only admin can view mandate settings (read-only) + # SysAdmin (flag) manages settings, roles only read + if adminId: + tableRules.append(AccessRule( + roleId=adminId, + context=AccessRuleContext.DATA, + item="data.billing.BillingSettings", + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + if userId: + tableRules.append(AccessRule( + roleId=userId, + context=AccessRuleContext.DATA, + item="data.billing.BillingSettings", + view=False, + read=AccessLevel.NONE, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + if viewerId: + tableRules.append(AccessRule( + roleId=viewerId, + context=AccessRuleContext.DATA, + item="data.billing.BillingSettings", + view=False, + read=AccessLevel.NONE, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + # Create all table-specific rules for rule in tableRules: db.recordCreate(AccessRule, rule) @@ -992,8 +1235,7 @@ def _ensureUiContextRules(db: DatabaseConnector) -> None: for rule in missingRules: db.recordCreate(AccessRule, rule) logger.info(f"Created {len(missingRules)} missing UI context rules") - else: - logger.debug("All UI context rules already exist") + # All UI context rules already exist (nothing to create) def _ensureDataContextRules(db: DatabaseConnector) -> None: @@ -1034,6 +1276,13 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None: "data.automation.AutomationTemplate", ] + # Billing tables: read-only for all roles, scoped by role level + # Users see their own accounts/transactions (MY), Admins see mandate-wide (GROUP) + billingReadOnlyTables = [ + "data.billing.BillingAccount", + "data.billing.BillingTransaction", + ] + missingRules = [] # MY-level rules for user-owned tables @@ -1077,9 +1326,9 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None: delete=AccessLevel.NONE, )) - # ALL-level rules for admin on system templates + # Admin rules for system templates (read ALL, write GROUP-scoped) for objectKey in tablesNeedingAllRulesForAdmin: - # Admin: ALL-level access (sees all templates) + # Admin: read ALL templates, create/update/delete within GROUP (mandate-scoped) if adminId and (adminId, objectKey) not in existingCombinations: missingRules.append(AccessRule( roleId=adminId, @@ -1087,9 +1336,9 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None: item=objectKey, view=True, read=AccessLevel.ALL, - create=AccessLevel.ALL, - update=AccessLevel.ALL, - delete=AccessLevel.ALL, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP, )) # User: MY-level access @@ -1118,13 +1367,89 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None: delete=AccessLevel.NONE, )) + # Billing read-only rules: Admin=GROUP, User/Viewer=MY (own accounts/transactions) + for objectKey in billingReadOnlyTables: + # Admin: GROUP-level read (sees all accounts in their mandates) + if adminId and (adminId, objectKey) not in existingCombinations: + missingRules.append(AccessRule( + roleId=adminId, + context=AccessRuleContext.DATA, + item=objectKey, + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # User: MY-level read (sees only own billing accounts/transactions) + if userId and (userId, objectKey) not in existingCombinations: + missingRules.append(AccessRule( + roleId=userId, + context=AccessRuleContext.DATA, + item=objectKey, + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # Viewer: MY-level read-only (sees only own billing accounts/transactions) + if viewerId and (viewerId, objectKey) not in existingCombinations: + missingRules.append(AccessRule( + roleId=viewerId, + context=AccessRuleContext.DATA, + item=objectKey, + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # BillingSettings: Admin can view (GROUP), User/Viewer have no access + billingSettingsKey = "data.billing.BillingSettings" + if adminId and (adminId, billingSettingsKey) not in existingCombinations: + missingRules.append(AccessRule( + roleId=adminId, + context=AccessRuleContext.DATA, + item=billingSettingsKey, + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + if userId and (userId, billingSettingsKey) not in existingCombinations: + missingRules.append(AccessRule( + roleId=userId, + context=AccessRuleContext.DATA, + item=billingSettingsKey, + view=False, + read=AccessLevel.NONE, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + if viewerId and (viewerId, billingSettingsKey) not in existingCombinations: + missingRules.append(AccessRule( + roleId=viewerId, + context=AccessRuleContext.DATA, + item=billingSettingsKey, + view=False, + read=AccessLevel.NONE, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + # Create missing rules if missingRules: for rule in missingRules: db.recordCreate(AccessRule, rule) logger.info(f"Created {len(missingRules)} missing DATA context rules") - else: - logger.debug("All DATA context rules already exist") + # All DATA context rules already exist (nothing to create) # Update existing AutomationTemplate rules for admin/viewer to ALL access _updateAutomationTemplateRulesToAll(db, adminId, viewerId) @@ -1132,8 +1457,9 @@ def _ensureDataContextRules(db: DatabaseConnector) -> None: def _updateAutomationTemplateRulesToAll(db: DatabaseConnector, adminId: Optional[str], viewerId: Optional[str]) -> None: """ - Update existing AutomationTemplate RBAC rules from MY to ALL for admin and viewer. - This ensures sysadmins can see all templates (including system-seeded ones). + Update existing AutomationTemplate RBAC rules to correct levels. + - Admin: read=ALL, create/update/delete=GROUP (mandate-scoped writes) + - Viewer: read=ALL (read-only) """ if not adminId and not viewerId: return @@ -1155,14 +1481,29 @@ def _updateAutomationTemplateRulesToAll(db: DatabaseConnector, adminId: Optional roleId = rule.get("roleId") currentReadLevel = rule.get("read") - # Update admin and viewer rules from MY to ALL - if roleId in [adminId, viewerId] and currentReadLevel == AccessLevel.MY.value: + if roleId == adminId: + # Admin: read ALL, write GROUP + updates = {} + if currentReadLevel != AccessLevel.ALL.value: + updates["read"] = AccessLevel.ALL.value + currentCreate = rule.get("create") + if currentCreate == AccessLevel.ALL.value: + updates["create"] = AccessLevel.GROUP.value + updates["update"] = AccessLevel.GROUP.value + updates["delete"] = AccessLevel.GROUP.value + if updates: + db.recordModify(AccessRule, ruleId, updates) + updatedCount += 1 + logger.debug(f"Updated AutomationTemplate rule {ruleId} for admin to read=ALL, write=GROUP") + + elif roleId == viewerId and currentReadLevel == AccessLevel.MY.value: + # Viewer: read ALL (read-only) db.recordModify(AccessRule, ruleId, {"read": AccessLevel.ALL.value}) updatedCount += 1 - logger.debug(f"Updated AutomationTemplate rule {ruleId} for role {roleId} to ALL access") + logger.debug(f"Updated AutomationTemplate rule {ruleId} for viewer to read=ALL") if updatedCount > 0: - logger.info(f"Updated {updatedCount} AutomationTemplate RBAC rules to ALL access") + logger.info(f"Updated {updatedCount} AutomationTemplate RBAC rules") def _createResourceContextRules(db: DatabaseConnector) -> None: @@ -1177,8 +1518,8 @@ def _createResourceContextRules(db: DatabaseConnector) -> None: """ resourceRules = [] - # All roles get full resource access by default (no sysadmin - that's a flag) - for roleLabel in ["admin", "user", "viewer"]: + # Admin and User get default resource access; Viewer gets NO resource access + for roleLabel in ["admin", "user"]: roleId = _getRoleId(db, roleLabel) if roleId: resourceRules.append(AccessRule( @@ -1192,6 +1533,8 @@ def _createResourceContextRules(db: DatabaseConnector) -> None: delete=None, )) + # Viewer: no default RESOURCE access (viewer cannot use system resources) + for rule in resourceRules: db.recordCreate(AccessRule, rule) @@ -1204,7 +1547,11 @@ def _createResourceContextRules(db: DatabaseConnector) -> None: def _createAicoreProviderRules(db: DatabaseConnector) -> None: """ Create RBAC rules for AICore providers (resource.aicore.{provider}). - All roles get access to all providers by default. + + Provider access per role: + - admin: all providers allowed + - user: all providers EXCEPT anthropic (view=False) + - viewer: NO provider access (viewer has no RESOURCE permissions) NOTE: Provider list is dynamically discovered from AICore model registry. @@ -1226,37 +1573,54 @@ def _createAicoreProviderRules(db: DatabaseConnector) -> None: providerRules = [] - # All roles get access to all providers (as per requirement) - for roleLabel in ["admin", "user", "viewer"]: - roleId = _getRoleId(db, roleLabel) - if not roleId: - continue - + # Admin: access to ALL providers + adminId = _getRoleId(db, "admin") + if adminId: for provider in providers: resourceKey = f"resource.aicore.{provider}" - - # Check if rule already exists existingRules = db.getRecordset( AccessRule, recordFilter={ - "roleId": roleId, + "roleId": adminId, "context": AccessRuleContext.RESOURCE.value, "item": resourceKey } ) - if not existingRules: providerRules.append(AccessRule( - roleId=roleId, + roleId=adminId, context=AccessRuleContext.RESOURCE, item=resourceKey, - view=True, # view=True means "can use" for RESOURCE context - read=None, - create=None, - update=None, - delete=None, + view=True, + read=None, create=None, update=None, delete=None, )) + # User: access to all providers EXCEPT anthropic + userId = _getRoleId(db, "user") + if userId: + for provider in providers: + resourceKey = f"resource.aicore.{provider}" + existingRules = db.getRecordset( + AccessRule, + recordFilter={ + "roleId": userId, + "context": AccessRuleContext.RESOURCE.value, + "item": resourceKey + } + ) + if not existingRules: + # Anthropic is not allowed for user role + isAllowed = provider != "anthropic" + providerRules.append(AccessRule( + roleId=userId, + context=AccessRuleContext.RESOURCE, + item=resourceKey, + view=isAllowed, + read=None, create=None, update=None, delete=None, + )) + + # Viewer: NO provider access (viewer has no RESOURCE permissions at all) + for rule in providerRules: db.recordCreate(AccessRule, rule) @@ -1273,7 +1637,7 @@ def initRootMandateBilling(mandateId: str) -> None: """ Initialize billing settings for root mandate. Root mandate uses PREPAY_USER model with 10 CHF initial credit per user. - Also creates billing accounts for all users of the mandate. + Creates billing accounts for ALL users regardless of billing model (for audit trail). Args: mandateId: Root mandate ID @@ -1291,11 +1655,10 @@ def initRootMandateBilling(mandateId: str) -> None: if existingSettings: logger.info("Billing settings for root mandate already exist") else: - # Create billing settings for root mandate settings = BillingSettings( mandateId=mandateId, billingModel=BillingModelEnum.PREPAY_USER, - defaultUserCredit=10.0, # 10 CHF initial credit per user + defaultUserCredit=10.0, warningThresholdPercent=10.0, blockOnZeroBalance=True, notifyOnWarning=True @@ -1305,28 +1668,34 @@ def initRootMandateBilling(mandateId: str) -> None: logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit") existingSettings = billingInterface.getSettings(mandateId) - # Create billing accounts for all users of the mandate + # Always create user accounts for all users (audit trail) if existingSettings: billingModel = existingSettings.get("billingModel", "UNLIMITED") + if billingModel == BillingModelEnum.UNLIMITED.value: + return # No accounts needed for UNLIMITED + + # Initial balance depends on billing model if billingModel == BillingModelEnum.PREPAY_USER.value: - defaultCredit = existingSettings.get("defaultUserCredit", 10.0) - userMandates = appInterface.getUserMandatesByMandate(mandateId) - accountsCreated = 0 - - for um in userMandates: - userId = um.get("userId") if isinstance(um, dict) else getattr(um, "userId", None) - if userId: - existingAccount = billingInterface.getUserAccount(mandateId, userId) - if not existingAccount: - billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit) - accountsCreated += 1 - logger.debug(f"Created billing account for user {userId}") - - if accountsCreated > 0: - logger.info(f"Created {accountsCreated} billing accounts for root mandate users with {defaultCredit} CHF each") + initialBalance = existingSettings.get("defaultUserCredit", 10.0) + else: + initialBalance = 0.0 # PREPAY_MANDATE / CREDIT_POSTPAY: budget on pool + + userMandates = appInterface.getUserMandatesByMandate(mandateId) + accountsCreated = 0 + + for um in userMandates: + userId = um.get("userId") if isinstance(um, dict) else getattr(um, "userId", None) + if userId: + existingAccount = billingInterface.getUserAccount(mandateId, userId) + if not existingAccount: + billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=initialBalance) + accountsCreated += 1 + logger.debug(f"Created billing account for user {userId}") + + if accountsCreated > 0: + logger.info(f"Created {accountsCreated} billing accounts for root mandate users with {initialBalance} CHF each") except Exception as e: - # Don't fail bootstrap if billing init fails logger.warning(f"Failed to initialize root mandate billing (non-critical): {e}") @@ -1349,10 +1718,14 @@ def assignInitialUserMemberships( adminUserId: Admin user ID eventUserId: Event user ID """ - # Use "admin" role for mandate membership (SysAdmin is a flag, not a role!) - adminRoleId = _getRoleId(db, "admin") + # Use mandate-instance "admin" role (not the global template) + mandateAdminRoles = db.getRecordset( + Role, + recordFilter={"roleLabel": "admin", "mandateId": mandateId, "featureInstanceId": None} + ) + adminRoleId = mandateAdminRoles[0].get("id") if mandateAdminRoles else None if not adminRoleId: - logger.warning("Admin role not found, skipping membership assignment") + logger.warning(f"Admin role not found for mandate {mandateId}, skipping membership assignment") return for userId, userName in [(adminUserId, "admin"), (eventUserId, "event")]: @@ -1364,7 +1737,6 @@ def assignInitialUserMemberships( if existingMemberships: userMandateId = existingMemberships[0].get("id") - logger.debug(f"UserMandate already exists for {userName} user") else: # Create UserMandate userMandate = UserMandate( diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py index 1d8359a5..68fee415 100644 --- a/modules/interfaces/interfaceDbApp.py +++ b/modules/interfaces/interfaceDbApp.py @@ -153,9 +153,6 @@ class AppObjects: userId=self.userId, ) - # Initialize database system - self.db.initDbSystem() - logger.info(f"Database initialized successfully for user {self.userId}") except Exception as e: logger.error(f"Failed to initialize database: {str(e)}") @@ -482,17 +479,12 @@ class AppObjects: """Returns the initial ID for a table.""" return self.db.getInitialId(model_class) - def _getDefaultMandateId(self) -> str: - """Get the default mandate ID, creating it if necessary.""" - defaultMandateId = self.getInitialId(Mandate) - if not defaultMandateId: - # If no default mandate exists, create one - logger.warning("No default mandate found, creating Root mandate") - self._initRootMandate() - defaultMandateId = self.getInitialId(Mandate) - if not defaultMandateId: - raise ValueError("Failed to get or create default mandate") - return defaultMandateId + def _getRootMandateId(self) -> Optional[str]: + """Get the root mandate ID (name='root', isSystem=True).""" + rootMandates = self.db.getRecordset(Mandate, recordFilter={"name": "root", "isSystem": True}) + if rootMandates: + return rootMandates[0].get("id") + return None def _getPasswordHash(self, password: str) -> str: """Creates a hash for a password.""" @@ -757,8 +749,9 @@ class AppObjects: # Clear cache to ensure fresh data (already done above) - # Grant access to root mandate features (chatplayground, automation) - self._grantRootMandateFeatureAccess(createdUser[0]["id"]) + # Assign new user to the root mandate with system 'viewer' role + userId = createdUser[0]["id"] + self._assignUserToRootMandate(userId) return User(**createdUser[0]) @@ -823,98 +816,47 @@ class AppObjects: logger.error(f"Error updating user: {str(e)}") raise ValueError(f"Failed to update user: {str(e)}") - def _grantRootMandateFeatureAccess(self, userId: str) -> None: + def _assignUserToRootMandate(self, userId: str) -> None: """ - Grant a new user access to root mandate features (chatplayground, automation). - Creates FeatureAccess with viewer role for each feature instance. + Assign a new user to the root mandate with the mandate-instance 'viewer' role. + This ensures every user has a base membership in the system mandate. + + Uses the mandate-instance role (mandateId=rootMandateId), not the global template. + Feature instance access is NOT granted here - it is managed separately + via invitations or admin assignment. Args: - userId: User ID to grant access to + userId: User ID to assign """ try: - from modules.datamodels.datamodelFeatures import FeatureInstance - from modules.datamodels.datamodelMembership import FeatureAccess, FeatureAccessRole from modules.datamodels.datamodelRbac import Role - # Get root mandate ID (first mandate in system) - allMandates = self.db.getRecordset(Mandate) - if not allMandates: - logger.debug("No mandates found, skipping feature access grant") - return - rootMandateId = allMandates[0].get("id") - - # Feature codes to grant access to - rootFeatureCodes = ["chatplayground", "automation"] - - # Get feature instances for root mandate - allInstances = self.db.getRecordset(FeatureInstance) - featureInstances = [ - inst for inst in allInstances - if inst.get("mandateId") == rootMandateId - and inst.get("featureCode") in rootFeatureCodes - and inst.get("enabled") == True - ] - - if not featureInstances: - logger.debug("No root mandate feature instances found, skipping feature access grant") + rootMandateId = self._getRootMandateId() + if not rootMandateId: + logger.warning("No root mandate found, skipping root mandate assignment") return - # Grant access to each feature instance - for instance in featureInstances: - instanceId = instance.get("id") - featureCode = instance.get("featureCode") - - # Check if user already has access - existingAccess = self.db.getRecordset( - FeatureAccess, - recordFilter={ - "userId": userId, - "featureInstanceId": instanceId - } - ) - - if existingAccess: - logger.debug(f"User {userId} already has access to feature instance {instanceId}") - continue - - # Create FeatureAccess - featureAccess = FeatureAccess( - userId=userId, - featureInstanceId=instanceId, - enabled=True - ) - createdAccess = self.db.recordCreate(FeatureAccess, featureAccess.model_dump()) - - if not createdAccess: - logger.warning(f"Failed to create FeatureAccess for user {userId} to instance {instanceId}") - continue - - featureAccessId = createdAccess.get("id") - - # Get viewer role for this feature instance - allRoles = self.db.getRecordset(Role) - viewerRoles = [ - r for r in allRoles - if r.get("featureInstanceId") == instanceId - and r.get("roleLabel") == "viewer" - ] - - if viewerRoles: - # Create FeatureAccessRole junction - featureAccessRole = FeatureAccessRole( - featureAccessId=featureAccessId, - roleId=viewerRoles[0].get("id") - ) - self.db.recordCreate(FeatureAccessRole, featureAccessRole.model_dump()) - logger.debug(f"Granted viewer role for {featureCode} to user {userId}") - else: - logger.warning(f"No viewer role found for feature instance {instanceId} ({featureCode})") + # Check if user already has a mandate membership + existing = self.getUserMandate(userId, rootMandateId) + if existing: + logger.debug(f"User {userId} already assigned to root mandate") + return - logger.info(f"Granted root mandate feature access to user {userId}") + # Find the mandate-instance 'viewer' role (bound to this mandate, not a global template) + mandateViewerRoles = self.db.getRecordset( + Role, + recordFilter={"roleLabel": "viewer", "mandateId": rootMandateId, "featureInstanceId": None} + ) + viewerRoleId = mandateViewerRoles[0].get("id") if mandateViewerRoles else None + + roleIds = [viewerRoleId] if viewerRoleId else [] + + self.createUserMandate(userId, rootMandateId, roleIds) + logger.info(f"Assigned user {userId} to root mandate with viewer role") except Exception as e: # Log but don't fail user creation - logger.error(f"Error granting root mandate feature access to user {userId}: {e}") + logger.error(f"Error assigning user {userId} to root mandate: {e}") def disableUser(self, userId: str) -> User: """Disables a user if current user has permission.""" @@ -1500,7 +1442,10 @@ class AppObjects: return Mandate(**filteredMandates[0]) def createMandate(self, name: str, description: str = None, enabled: bool = True) -> Mandate: - """Creates a new mandate if user has permission.""" + """ + Creates a new mandate if user has permission. + Automatically copies system template roles (admin, user, viewer) to the new mandate. + """ if not self.checkRbacPermission(Mandate, "create"): raise PermissionError("No permission to create mandates") @@ -1512,6 +1457,16 @@ class AppObjects: if not createdRecord or not createdRecord.get("id"): raise ValueError("Failed to create mandate record") + mandateId = createdRecord.get("id") + + # Copy system template roles to new mandate (admin, user, viewer + AccessRules) + try: + from modules.interfaces.interfaceBootstrap import copySystemRolesToMandate + copiedCount = copySystemRolesToMandate(self.db, mandateId) + logger.info(f"Copied {copiedCount} system roles to new mandate {mandateId}") + except Exception as e: + logger.error(f"Error copying system roles to mandate {mandateId}: {e}") + return Mandate(**createdRecord) def updateMandate(self, mandateId: str, updateData: Dict[str, Any]) -> Mandate: @@ -1526,9 +1481,13 @@ class AppObjects: if not mandate: raise ValueError(f"Mandate {mandateId} not found") + # Strip immutable/protected fields from update data + _protectedFields = {"id", "isSystem"} + _sanitizedData = {k: v for k, v in updateData.items() if k not in _protectedFields} + # Update mandate data using model updatedData = mandate.model_dump() - updatedData.update(updateData) + updatedData.update(_sanitizedData) updatedMandate = Mandate(**updatedData) # Update mandate record @@ -1548,13 +1507,17 @@ class AppObjects: raise ValueError(f"Failed to update mandate: {str(e)}") def deleteMandate(self, mandateId: str) -> bool: - """Deletes a mandate if user has access.""" + """Deletes a mandate if user has access. System mandates cannot be deleted.""" try: # Check if mandate exists and user has access mandate = self.getMandate(mandateId) if not mandate: return False + # System mandates (isSystem=True) cannot be deleted + if getattr(mandate, "isSystem", False): + raise ValueError(f"System mandate '{mandate.name}' cannot be deleted") + if not self.checkRbacPermission(Mandate, "delete", mandateId): raise PermissionError(f"No permission to delete mandate {mandateId}") @@ -1677,7 +1640,10 @@ class AppObjects: def _ensureUserBillingAccount(self, userId: str, mandateId: str) -> None: """ Ensure a user has a billing account for the mandate if billing is configured. - Creates account with default credit from settings if billingModel is PREPAY_USER. + User accounts are always created for all billing models (for audit trail). + Initial balance depends on billing model: + - PREPAY_USER: defaultUserCredit from settings + - PREPAY_MANDATE / CREDIT_POSTPAY: 0.0 (budget is on mandate pool) Args: userId: User ID @@ -1694,15 +1660,19 @@ class AppObjects: return # No billing configured for this mandate billingModel = settings.get("billingModel", "UNLIMITED") - if billingModel != BillingModelEnum.PREPAY_USER.value: - return # Only create user accounts for PREPAY_USER model + if billingModel == BillingModelEnum.UNLIMITED.value: + return # No accounts needed for UNLIMITED - defaultCredit = settings.get("defaultUserCredit", 10.0) - billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit) - logger.info(f"Created billing account for user {userId} in mandate {mandateId} with {defaultCredit} CHF") + # Initial balance depends on billing model + if billingModel == BillingModelEnum.PREPAY_USER.value: + initialBalance = settings.get("defaultUserCredit", 10.0) + else: + initialBalance = 0.0 # PREPAY_MANDATE / CREDIT_POSTPAY: budget is on pool + + billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=initialBalance) + logger.info(f"Ensured billing account for user {userId} in mandate {mandateId} (model={billingModel}, initial={initialBalance} CHF)") except Exception as e: - # Don't fail user mandate creation if billing account creation fails logger.warning(f"Failed to create billing account for user {userId} (non-critical): {e}") def deleteUserMandate(self, userId: str, mandateId: str) -> bool: diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index a8cbd61b..41be662c 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -417,7 +417,12 @@ class BillingObjects: def ensureAllUserAccountsExist(self) -> int: """ - Efficiently ensure all users across all mandates have billing accounts. + Ensure all users across all mandates have billing accounts. + User accounts are always created regardless of billing model (for audit trail). + Initial balance depends on billing model: + - PREPAY_USER: defaultUserCredit from settings + - PREPAY_MANDATE / CREDIT_POSTPAY: 0.0 (budget is on pool) + Uses bulk queries to minimize database connections. Returns: @@ -426,29 +431,31 @@ class BillingObjects: try: accountsCreated = 0 - # Step 1: Get all billing settings in one query (only PREPAY_USER mandates need user accounts) + # Step 1: Get all billing settings (all models except UNLIMITED need user accounts) allSettings = self.db.getRecordset(BillingSettings) - prepayUserMandates = {} + billingMandates = {} # mandateId -> (billingModel, defaultCredit) for s in allSettings: - if s.get("billingModel") == BillingModelEnum.PREPAY_USER.value: - prepayUserMandates[s.get("mandateId")] = s.get("defaultUserCredit", 10.0) + billingModel = s.get("billingModel", BillingModelEnum.UNLIMITED.value) + if billingModel == BillingModelEnum.UNLIMITED.value: + continue + defaultCredit = s.get("defaultUserCredit", 10.0) if billingModel == BillingModelEnum.PREPAY_USER.value else 0.0 + billingMandates[s.get("mandateId")] = (billingModel, defaultCredit) - if not prepayUserMandates: - logger.debug("No PREPAY_USER mandates found, skipping account check") + if not billingMandates: + logger.debug("No billable mandates found, skipping account check") return 0 - # Step 2: Get all existing USER accounts in one query (from billing DB) + # Step 2: Get all existing USER accounts in one query allAccounts = self.db.getRecordset( BillingAccount, recordFilter={"accountType": AccountTypeEnum.USER.value} ) - # Build set of existing (mandateId, userId) pairs existingAccountKeys = set() for acc in allAccounts: key = (acc.get("mandateId"), acc.get("userId")) existingAccountKeys.add(key) - # Step 3: Get all user-mandate combinations from APP database (separate connection) + # Step 3: Get all user-mandate combinations from APP database appDb = DatabaseConnector( dbDatabase=APP_CONFIG.get('DB_DATABASE', 'poweron_app'), dbHost=APP_CONFIG.get('DB_HOST', 'localhost'), @@ -461,7 +468,7 @@ class BillingObjects: recordFilter={"enabled": True} ) - # Step 4: Find missing accounts and create them + # Step 4: Create missing accounts for um in allUserMandates: mandateId = um.get("mandateId") userId = um.get("userId") @@ -469,17 +476,15 @@ class BillingObjects: if not mandateId or not userId: continue - # Only process mandates with PREPAY_USER billing - if mandateId not in prepayUserMandates: + if mandateId not in billingMandates: continue - # Check if account already exists (in memory, no DB call) key = (mandateId, userId) if key in existingAccountKeys: continue - # Create missing account - defaultCredit = prepayUserMandates[mandateId] + billingModel, defaultCredit = billingMandates[mandateId] + account = BillingAccount( mandateId=mandateId, userId=userId, @@ -489,7 +494,6 @@ class BillingObjects: ) created = self.createAccount(account) - # Create initial credit transaction if defaultCredit > 0: self.createTransaction(BillingTransaction( accountId=created["id"], @@ -499,7 +503,7 @@ class BillingObjects: referenceType=ReferenceTypeEnum.SYSTEM )) - existingAccountKeys.add(key) # Track newly created + existingAccountKeys.add(key) accountsCreated += 1 if accountsCreated > 0: @@ -515,22 +519,37 @@ class BillingObjects: # BillingTransaction Operations # ========================================================================= - def createTransaction(self, transaction: BillingTransaction) -> Dict[str, Any]: + def createTransaction(self, transaction: BillingTransaction, balanceAccountId: str = None) -> Dict[str, Any]: """ Create a new billing transaction and update account balance. + The transaction is always recorded against transaction.accountId (audit trail). + The balance is updated on balanceAccountId if provided, otherwise on transaction.accountId. + This allows recording a transaction on a user account (audit) while deducting + from a mandate pool account (shared budget). + Args: transaction: BillingTransaction object + balanceAccountId: Optional account ID for balance update (defaults to transaction.accountId) Returns: Created transaction dict """ - # Get current account - account = self.getAccount(transaction.accountId) - if not account: - raise ValueError(f"Account {transaction.accountId} not found") + # Validate that the transaction's account exists + txAccount = self.getAccount(transaction.accountId) + if not txAccount: + raise ValueError(f"Transaction account {transaction.accountId} not found") - currentBalance = account.get("balance", 0.0) + # Determine which account to update balance on + targetBalanceAccountId = balanceAccountId or transaction.accountId + if targetBalanceAccountId == transaction.accountId: + balanceAccount = txAccount + else: + balanceAccount = self.getAccount(targetBalanceAccountId) + if not balanceAccount: + raise ValueError(f"Balance account {targetBalanceAccountId} not found") + + currentBalance = balanceAccount.get("balance", 0.0) # Calculate new balance if transaction.transactionType == TransactionTypeEnum.CREDIT: @@ -538,17 +557,17 @@ class BillingObjects: elif transaction.transactionType == TransactionTypeEnum.DEBIT: newBalance = currentBalance - transaction.amount else: # ADJUSTMENT - newBalance = currentBalance + transaction.amount # Can be positive or negative + newBalance = currentBalance + transaction.amount - # Create transaction + # Create transaction record (always on transaction.accountId for audit) transactionDict = transaction.model_dump(exclude_none=True) created = self.db.recordCreate(BillingTransaction, transactionDict) - # Update account balance - self.updateAccountBalance(transaction.accountId, newBalance) + # Update balance on the target account + self.updateAccountBalance(targetBalanceAccountId, newBalance) logger.info(f"Billing transaction created: {transaction.transactionType.value} {transaction.amount} CHF, " - f"balance: {currentBalance} -> {newBalance}") + f"audit={transaction.accountId}, balance on {targetBalanceAccountId}: {currentBalance} -> {newBalance}") return created @@ -631,6 +650,14 @@ class BillingObjects: """ Check if there's sufficient balance for an operation. + Budget logic: + - PREPAY_USER: check user's own account balance + - PREPAY_MANDATE: check mandate pool balance (shared by all users) + - CREDIT_POSTPAY: check mandate pool credit limit + - UNLIMITED: always allowed + + User accounts are always ensured to exist (for audit trail). + Args: mandateId: Mandate ID userId: User ID @@ -641,43 +668,29 @@ class BillingObjects: """ settings = self.getSettings(mandateId) if not settings: - # No settings = no billing = allowed return BillingCheckResult(allowed=True, billingModel=BillingModelEnum.UNLIMITED) billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) - # UNLIMITED = always allowed if billingModel == BillingModelEnum.UNLIMITED: return BillingCheckResult(allowed=True, billingModel=billingModel) - # Get the relevant account + # Always ensure user account exists (for audit trail) + defaultCredit = settings.get("defaultUserCredit", 10.0) + initialBalance = defaultCredit if billingModel == BillingModelEnum.PREPAY_USER else 0.0 + self.getOrCreateUserAccount(mandateId, userId, initialBalance=initialBalance) + + # Determine which balance to check based on billing model if billingModel == BillingModelEnum.PREPAY_USER: account = self.getUserAccount(mandateId, userId) - # Auto-create user account if not exists (with default credit from settings) - if not account: - defaultCredit = settings.get("defaultUserCredit", 10.0) - logger.info(f"Auto-creating billing account for user {userId} in mandate {mandateId} with {defaultCredit} CHF initial credit") - account = self.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit) - else: - account = self.getMandateAccount(mandateId) - - if not account: - # No account (only happens for mandate-level accounts) = potentially blocked - if settings.get("blockOnZeroBalance", True): - return BillingCheckResult( - allowed=False, - reason="NO_ACCOUNT", - currentBalance=0.0, - requiredAmount=estimatedCost, - billingModel=billingModel - ) - return BillingCheckResult(allowed=True, currentBalance=0.0, billingModel=billingModel) - - currentBalance = account.get("balance", 0.0) - - # CREDIT_POSTPAY with credit limit check - if billingModel == BillingModelEnum.CREDIT_POSTPAY: - creditLimit = account.get("creditLimit") + currentBalance = account.get("balance", 0.0) if account else 0.0 + elif billingModel == BillingModelEnum.PREPAY_MANDATE: + poolAccount = self.getOrCreateMandateAccount(mandateId) + currentBalance = poolAccount.get("balance", 0.0) + elif billingModel == BillingModelEnum.CREDIT_POSTPAY: + poolAccount = self.getOrCreateMandateAccount(mandateId) + currentBalance = poolAccount.get("balance", 0.0) + creditLimit = poolAccount.get("creditLimit") if creditLimit and abs(currentBalance) + estimatedCost > creditLimit: return BillingCheckResult( allowed=False, @@ -687,6 +700,8 @@ class BillingObjects: billingModel=billingModel ) return BillingCheckResult(allowed=True, currentBalance=currentBalance, billingModel=billingModel) + else: + return BillingCheckResult(allowed=True, billingModel=billingModel) # PREPAY models - check balance if currentBalance < estimatedCost: @@ -716,6 +731,12 @@ class BillingObjects: """ Record usage cost as a billing transaction. + Transaction is ALWAYS recorded on the user's account (clean audit trail). + Balance is deducted from the appropriate account based on billing model: + - PREPAY_USER: deduct from user's own balance + - PREPAY_MANDATE: deduct from mandate pool balance + - CREDIT_POSTPAY: deduct from mandate pool balance + Args: mandateId: Mandate ID userId: User ID @@ -740,19 +761,14 @@ class BillingObjects: billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) - # UNLIMITED = no transaction recording if billingModel == BillingModelEnum.UNLIMITED: return None - # Get or create the relevant account - if billingModel == BillingModelEnum.PREPAY_USER: - account = self.getOrCreateUserAccount(mandateId, userId) - else: - account = self.getOrCreateMandateAccount(mandateId) + # Transaction is ALWAYS on the user's account (audit trail) + userAccount = self.getOrCreateUserAccount(mandateId, userId) - # Create debit transaction transaction = BillingTransaction( - accountId=account["id"], + accountId=userAccount["id"], transactionType=TransactionTypeEnum.DEBIT, amount=priceCHF, description=description, @@ -765,7 +781,84 @@ class BillingObjects: createdByUserId=userId ) - return self.createTransaction(transaction) + # Determine where to deduct balance + if billingModel == BillingModelEnum.PREPAY_USER: + # Deduct from user's own balance + return self.createTransaction(transaction) + else: + # PREPAY_MANDATE / CREDIT_POSTPAY: deduct from mandate pool + poolAccount = self.getOrCreateMandateAccount(mandateId) + return self.createTransaction(transaction, balanceAccountId=poolAccount["id"]) + + # ========================================================================= + # Billing Model Switch Operations + # ========================================================================= + + def switchBillingModel(self, mandateId: str, oldModel: BillingModelEnum, newModel: BillingModelEnum) -> Dict[str, Any]: + """ + Switch billing model with automatic budget migration. + + MANDATE -> USER: pool balance is distributed equally to all user accounts. + USER -> MANDATE: all user balances are consolidated into the pool, user balances set to 0. + + Args: + mandateId: Mandate ID + oldModel: Current billing model + newModel: New billing model + + Returns: + Migration result dict with details + """ + result = {"oldModel": oldModel.value, "newModel": newModel.value, "migratedAmount": 0.0, "userCount": 0} + + if oldModel == newModel: + return result + + if oldModel == BillingModelEnum.PREPAY_MANDATE and newModel == BillingModelEnum.PREPAY_USER: + # Pool -> distribute equally to users + poolAccount = self.getMandateAccount(mandateId) + if poolAccount and poolAccount.get("balance", 0.0) > 0: + poolBalance = poolAccount["balance"] + userAccounts = self.db.getRecordset( + BillingAccount, + recordFilter={"mandateId": mandateId, "accountType": AccountTypeEnum.USER.value} + ) + if userAccounts: + perUser = poolBalance / len(userAccounts) + for acc in userAccounts: + newBalance = acc.get("balance", 0.0) + perUser + self.updateAccountBalance(acc["id"], newBalance) + self.updateAccountBalance(poolAccount["id"], 0.0) + result["migratedAmount"] = poolBalance + result["userCount"] = len(userAccounts) + + logger.info(f"Switched {mandateId} MANDATE->USER: distributed {result['migratedAmount']} CHF to {result['userCount']} users") + + elif oldModel == BillingModelEnum.PREPAY_USER and newModel == BillingModelEnum.PREPAY_MANDATE: + # Users -> consolidate into pool + userAccounts = self.db.getRecordset( + BillingAccount, + recordFilter={"mandateId": mandateId, "accountType": AccountTypeEnum.USER.value} + ) + totalUserBalance = sum(acc.get("balance", 0.0) for acc in userAccounts) + + poolAccount = self.getOrCreateMandateAccount(mandateId, initialBalance=0.0) + newPoolBalance = poolAccount.get("balance", 0.0) + totalUserBalance + self.updateAccountBalance(poolAccount["id"], newPoolBalance) + + for acc in userAccounts: + self.updateAccountBalance(acc["id"], 0.0) + + result["migratedAmount"] = totalUserBalance + result["userCount"] = len(userAccounts) + + logger.info(f"Switched {mandateId} USER->MANDATE: consolidated {totalUserBalance} CHF from {len(userAccounts)} users into pool") + + elif newModel == BillingModelEnum.PREPAY_MANDATE or newModel == BillingModelEnum.CREDIT_POSTPAY: + # Any -> MANDATE/CREDIT: ensure pool account exists + self.getOrCreateMandateAccount(mandateId, initialBalance=0.0) + + return result # ========================================================================= # Statistics Operations @@ -862,6 +955,11 @@ class BillingObjects: """ Get all billing balances for a user across mandates. + Shows the effective available budget: + - PREPAY_USER: user's own account balance + - PREPAY_MANDATE: mandate pool balance (shared budget visible to user) + - CREDIT_POSTPAY: mandate pool balance + Args: userId: User ID @@ -872,13 +970,11 @@ class BillingObjects: balances = [] - # Get all mandates the user belongs to try: appInterface = getAppInterface(self.currentUser) userMandates = appInterface.getUserMandates(userId) for um in userMandates: - # Handle both Pydantic models and dicts mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) if not mandateId: continue @@ -887,7 +983,6 @@ class BillingObjects: if not mandate: continue - # Get mandate name (handle both Pydantic and dict) mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") settings = self.getSettings(mandateId) @@ -895,21 +990,27 @@ class BillingObjects: continue billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) + if billingModel == BillingModelEnum.UNLIMITED: + continue - # Get the relevant account + # Determine effective balance based on billing model if billingModel == BillingModelEnum.PREPAY_USER: account = self.getUserAccount(mandateId, userId) + if not account: + continue + balance = account.get("balance", 0.0) + warningThreshold = account.get("warningThreshold", 0.0) + creditLimit = account.get("creditLimit") elif billingModel in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]: - account = self.getMandateAccount(mandateId) + poolAccount = self.getMandateAccount(mandateId) + if not poolAccount: + continue + balance = poolAccount.get("balance", 0.0) + warningThreshold = poolAccount.get("warningThreshold", 0.0) + creditLimit = poolAccount.get("creditLimit") else: continue - if not account: - continue - - balance = account.get("balance", 0.0) - warningThreshold = account.get("warningThreshold", 0.0) - balances.append(BillingBalanceResponse( mandateId=mandateId, mandateName=mandateName, @@ -917,7 +1018,7 @@ class BillingObjects: balance=balance, warningThreshold=warningThreshold, isWarning=balance <= warningThreshold, - creditLimit=account.get("creditLimit") + creditLimit=creditLimit )) except Exception as e: logger.error(f"Error getting balances for user: {e}") @@ -927,6 +1028,8 @@ class BillingObjects: def getTransactionsForUser(self, userId: str, limit: int = 100) -> List[Dict[str, Any]]: """ Get all transactions for a user across all mandates they belong to. + Since transactions are always recorded on user accounts, we query + directly by user account - clean and simple. Args: userId: User ID @@ -944,20 +1047,22 @@ class BillingObjects: userMandates = appInterface.getUserMandates(userId) for um in userMandates: - # Handle both Pydantic models and dicts mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) if not mandateId: continue - # Only include mandates with billing settings settings = self.getSettings(mandateId) if not settings: continue - # Get transactions for this mandate - transactions = self.getTransactionsByMandate(mandateId, limit=limit) + # Get user's account in this mandate + userAccount = self.getUserAccount(mandateId, userId) + if not userAccount: + continue + + # Get transactions for user's account (all transactions are on user accounts now) + transactions = self.getTransactions(userAccount["id"], limit=limit) - # Add mandate context to each transaction mandate = appInterface.getMandate(mandateId) mandateName = "" if mandate: @@ -971,7 +1076,6 @@ class BillingObjects: except Exception as e: logger.error(f"Error getting transactions for user: {e}") - # Sort by creation date descending and limit allTransactions.sort(key=lambda x: x.get("_createdAt", ""), reverse=True) return allTransactions[:limit] @@ -1016,23 +1120,23 @@ class BillingObjects: if mandate: mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") - # For PREPAY_MANDATE, get the mandate account balance - # For PREPAY_USER, aggregate all user balances - if billingModel == BillingModelEnum.PREPAY_MANDATE: - account = self.getMandateAccount(mandateId) - totalBalance = account.get("balance", 0.0) if account else 0.0 - userCount = 0 - elif billingModel == BillingModelEnum.PREPAY_USER: - # Get all user accounts for this mandate - userAccounts = self.db.getRecordset( - BillingAccount, - recordFilter={"mandateId": mandateId, "accountType": AccountTypeEnum.USER.value} - ) + # Get user accounts count (always exist now for audit trail) + userAccounts = self.db.getRecordset( + BillingAccount, + recordFilter={"mandateId": mandateId, "accountType": AccountTypeEnum.USER.value} + ) + userCount = len(userAccounts) + + # Total balance depends on billing model + if billingModel == BillingModelEnum.PREPAY_USER: + # Budget is distributed across user accounts totalBalance = sum(acc.get("balance", 0.0) for acc in userAccounts) - userCount = len(userAccounts) + elif billingModel in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]: + # Budget is in the mandate pool + poolAccount = self.getMandateAccount(mandateId) + totalBalance = poolAccount.get("balance", 0.0) if poolAccount else 0.0 else: totalBalance = 0.0 - userCount = 0 balances.append({ "mandateId": mandateId, @@ -1183,7 +1287,8 @@ class BillingObjects: def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]: """ - Get all transactions for specified mandates (both USER and MANDATE accounts). + Get all transactions for specified mandates. + All usage transactions are on user accounts (audit trail). Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. diff --git a/modules/interfaces/interfaceDbChat.py b/modules/interfaces/interfaceDbChat.py index 9ee20fc0..e7925dbd 100644 --- a/modules/interfaces/interfaceDbChat.py +++ b/modules/interfaces/interfaceDbChat.py @@ -329,9 +329,6 @@ class ChatObjects: userId=self.userId ) - # Initialize database system - self.db.initDbSystem() - logger.info("Database initialized successfully") except Exception as e: logger.error(f"Failed to initialize database: {str(e)}") diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py index 10c47a19..59fc4c1c 100644 --- a/modules/interfaces/interfaceDbManagement.py +++ b/modules/interfaces/interfaceDbManagement.py @@ -141,9 +141,6 @@ class ComponentObjects: userId=self.userId if hasattr(self, 'userId') else None ) - # Initialize database system - self.db.initDbSystem() - logger.info("Database initialized successfully") except Exception as e: logger.error(f"Failed to initialize database: {str(e)}") diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index 3e87cd83..e785ed13 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -455,11 +455,8 @@ def getStatistics( billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value)) - # Get the relevant account - if billingModel == BillingModelEnum.PREPAY_USER: - account = billingInterface.getUserAccount(ctx.mandateId, ctx.user.id) - else: - account = billingInterface.getMandateAccount(ctx.mandateId) + # Transactions are always on user accounts (audit trail) + account = billingInterface.getUserAccount(ctx.mandateId, ctx.user.id) if not account: return UsageReportResponse( @@ -578,14 +575,20 @@ def createOrUpdateSettings( existingSettings = billingInterface.getSettings(targetMandateId) if existingSettings: - # Update existing settings updates = settingsUpdate.model_dump(exclude_none=True) if updates: + # Check if billing model is changing - trigger budget migration + if "billingModel" in updates: + oldModel = BillingModelEnum(existingSettings.get("billingModel", BillingModelEnum.UNLIMITED.value)) + newModel = BillingModelEnum(updates["billingModel"]) if isinstance(updates["billingModel"], str) else updates["billingModel"] + if oldModel != newModel: + migrationResult = billingInterface.switchBillingModel(targetMandateId, oldModel, newModel) + logger.info(f"Billing model migration for {targetMandateId}: {migrationResult}") + result = billingInterface.updateSettings(existingSettings["id"], updates) return result or existingSettings return existingSettings else: - # Create new settings from modules.datamodels.datamodelBilling import BillingSettings newSettings = BillingSettings( diff --git a/modules/routes/routeInvitations.py b/modules/routes/routeInvitations.py index 095b84fb..01c395e2 100644 --- a/modules/routes/routeInvitations.py +++ b/modules/routes/routeInvitations.py @@ -41,6 +41,7 @@ class InvitationCreate(BaseModel): email: Optional[str] = Field(None, description="Email address to send invitation link (optional)") roleIds: List[str] = Field(..., description="Role IDs to assign to the invited user") featureInstanceId: Optional[str] = Field(None, description="Optional feature instance access") + frontendUrl: str = Field(..., description="Frontend URL for building the invite link (provided by frontend)") expiresInHours: int = Field( 72, ge=1, @@ -178,10 +179,9 @@ def create_invitation( if not createdRecord: raise ValueError("Failed to create invitation record") - # Build invite URL - from modules.shared.configuration import APP_CONFIG - frontendUrl = APP_CONFIG.get("APP_FRONTEND_URL", "http://localhost:8080") - inviteUrl = f"{frontendUrl}/invite/{invitation.token}" + # Build invite URL using frontend URL provided by the caller + baseUrl = data.frontendUrl.rstrip("/") + inviteUrl = f"{baseUrl}/invite/{invitation.token}" # Send email if email address is provided emailSent = False @@ -302,6 +302,7 @@ def create_invitation( @limiter.limit("60/minute") def list_invitations( request: Request, + frontendUrl: str = Query(..., description="Frontend URL for building invite links (provided by frontend)"), includeUsed: bool = Query(False, description="Include already used invitations"), includeExpired: bool = Query(False, description="Include expired invitations"), context: RequestContext = Depends(getRequestContext) @@ -353,10 +354,9 @@ def list_invitations( if not includeExpired and expiresAt < currentTime: continue - # Build invite URL - from modules.shared.configuration import APP_CONFIG - frontendUrl = APP_CONFIG.get("APP_FRONTEND_URL", "http://localhost:8080") - inviteUrl = f"{frontendUrl}/invite/{inv.token}" + # Build invite URL using frontend URL provided by the caller + baseUrl = frontendUrl.rstrip("/") + inviteUrl = f"{baseUrl}/invite/{inv.token}" result.append({ **inv.model_dump(), diff --git a/modules/security/rbac.py b/modules/security/rbac.py index c661e795..0a2136b1 100644 --- a/modules/security/rbac.py +++ b/modules/security/rbac.py @@ -13,7 +13,7 @@ Multi-Tenant Design: import logging from typing import List, Optional, TYPE_CHECKING from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role -from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel, Mandate +from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel from modules.datamodels.datamodelMembership import ( UserMandate, UserMandateRole, @@ -155,10 +155,16 @@ class RbacClass: ) -> List[str]: """ Get all role IDs for a user in the given context. - Uses UserMandate + UserMandateRole for the new multi-tenant model. + Uses UserMandate + UserMandateRole for the multi-tenant model. - Also includes roles from the Root mandate (first mandate) if different - from the requested mandate, so system-level permissions are always available. + Each mandate has its own instances of system roles (admin, user, viewer) + which are copied from the global templates during mandate creation. + Therefore, only the requested mandate's roles are loaded - no need to + load root mandate roles separately. + + Loads roles from: + 1. The requested mandate (if provided) - includes mandate-instance system roles + 2. Feature instance roles (if featureInstanceId provided) Args: user: User object @@ -171,24 +177,11 @@ class RbacClass: roleIds = set() # Use set to avoid duplicates try: - # Get Root mandate ID (first mandate in system) - allMandates = self.dbApp.getRecordset(Mandate) - rootMandateId = allMandates[0]["id"] if allMandates else None - - # Collect mandates to check: - # - If mandateId provided: current mandate + Root mandate (if different) - # - If no mandateId: just Root mandate (for system-level access) - mandatesToCheck = [] + # Load roles from the requested mandate if mandateId: - mandatesToCheck.append(mandateId) - if rootMandateId and rootMandateId not in mandatesToCheck: - mandatesToCheck.append(rootMandateId) - - # Load roles from each mandate - for checkMandateId in mandatesToCheck: userMandateRecords = self.dbApp.getRecordset( UserMandate, - recordFilter={"userId": user.id, "mandateId": checkMandateId, "enabled": True} + recordFilter={"userId": user.id, "mandateId": mandateId, "enabled": True} ) if userMandateRecords: diff --git a/modules/services/serviceBilling/mainServiceBilling.py b/modules/services/serviceBilling/mainServiceBilling.py index 8bf1c2c4..8407304c 100644 --- a/modules/services/serviceBilling/mainServiceBilling.py +++ b/modules/services/serviceBilling/mainServiceBilling.py @@ -27,8 +27,8 @@ from modules.interfaces.interfaceDbBilling import getInterface as getBillingInte logger = logging.getLogger(__name__) -# Markup percentage for internal pricing (50% = 1.5x) -BILLING_MARKUP_PERCENT = 50 +# Markup percentage for internal pricing (+50% für Infrastruktur und Platform Service + 50% für Währungsrisiko ==> Faktor 2.0) +BILLING_MARKUP_PERCENT = 100 # Singleton cache _billingServices: Dict[str, "BillingService"] = {} diff --git a/modules/shared/eventManagement.py b/modules/shared/eventManagement.py index 1e473ccb..3bb45af8 100644 --- a/modules/shared/eventManagement.py +++ b/modules/shared/eventManagement.py @@ -113,7 +113,7 @@ class EventManagement: self.scheduler.remove_job(jobId) logger.info(f"Removed job '{jobId}'") except Exception as exc: - logger.warning(f"Could not remove job '{jobId}': {exc}") + logger.debug(f"Could not remove job '{jobId}': {exc}") # Singleton instance for easy import and reuse diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py index 01d56eb4..3137604a 100644 --- a/modules/system/mainSystem.py +++ b/modules/system/mainSystem.py @@ -363,6 +363,43 @@ RESOURCE_OBJECTS = [ ] +def _discoverAicoreProviderObjects() -> List[Dict[str, Any]]: + """ + Dynamically discover AICore provider resources for the RBAC catalog. + Providers are discovered from the model registry at startup. + """ + providerLabels = { + "anthropic": {"en": "Anthropic (Claude)", "de": "Anthropic (Claude)", "fr": "Anthropic (Claude)"}, + "openai": {"en": "OpenAI (GPT)", "de": "OpenAI (GPT)", "fr": "OpenAI (GPT)"}, + "perplexity": {"en": "Perplexity", "de": "Perplexity", "fr": "Perplexity"}, + "tavily": {"en": "Tavily (Web Search)", "de": "Tavily (Websuche)", "fr": "Tavily (Recherche Web)"}, + "privatellm": {"en": "Private LLM", "de": "Private LLM", "fr": "LLM Privé"}, + "internal": {"en": "Internal", "de": "Intern", "fr": "Interne"}, + } + + try: + from modules.aicore.aicoreModelRegistry import modelRegistry + connectors = modelRegistry.discoverConnectors() + providers = [c.getConnectorType() for c in connectors] + + objects = [] + for provider in providers: + label = providerLabels.get(provider, {"en": provider, "de": provider, "fr": provider}) + objects.append({ + "objectKey": f"resource.aicore.{provider}", + "label": label, + "meta": {"provider": provider, "category": "aicore"} + }) + + if objects: + logger.info(f"Discovered {len(objects)} AICore provider catalog objects: {providers}") + return objects + + except Exception as e: + logger.warning(f"Failed to discover AICore providers for catalog: {e}") + return [] + + def registerFeature(catalogService) -> bool: """ Register system RBAC objects in the catalog. @@ -401,6 +438,16 @@ def registerFeature(catalogService) -> bool: meta=resObj.get("meta") ) + # Register dynamically discovered AICore provider resources + aicoreObjects = _discoverAicoreProviderObjects() + for aicoreObj in aicoreObjects: + catalogService.registerResourceObject( + featureCode=FEATURE_CODE, + objectKey=aicoreObj["objectKey"], + label=aicoreObj["label"], + meta=aicoreObj.get("meta") + ) + # Register feature definition catalogService.registerFeatureDefinition( featureCode=FEATURE_CODE, diff --git a/modules/system/registry.py b/modules/system/registry.py index f7e50524..1c32badd 100644 --- a/modules/system/registry.py +++ b/modules/system/registry.py @@ -86,22 +86,20 @@ def loadFeatureRouters(app: FastAPI) -> Dict[str, Any]: logger.error(f"Failed to load router from {featureDir}: {e}") results[featureDir] = {"status": "error", "error": str(e)} - # Register features in RBAC catalog and sync template roles to database - from modules.security.rbacCatalog import getCatalogService - catalogService = getCatalogService() - registrationResults = registerAllFeaturesInCatalog(catalogService) - - for featureName, success in registrationResults.items(): - if featureName in results: - results[featureName]["rbac_registered"] = success - return results +_cachedMainModules = None + def loadFeatureMainModules() -> Dict[str, Any]: """ Dynamically load main modules from all discovered feature containers. + Results are cached after the first call. """ + global _cachedMainModules + if _cachedMainModules is not None: + return _cachedMainModules + mainModules = {} pattern = os.path.join(FEATURES_DIR, "*", "main*.py") @@ -113,6 +111,10 @@ def loadFeatureMainModules() -> Dict[str, Any]: featureDir = os.path.basename(os.path.dirname(filepath)) if featureDir.startswith("_"): continue + + # Skip if this feature already has a main module loaded (avoid duplicates) + if featureDir in mainModules: + continue mainFile = filename[:-3] # Remove .py @@ -124,6 +126,7 @@ def loadFeatureMainModules() -> Dict[str, Any]: except Exception as e: logger.error(f"Failed to load main module from {featureDir}: {e}") + _cachedMainModules = mainModules return mainModules diff --git a/modules/workflows/workflowManager.py b/modules/workflows/workflowManager.py index 030a966f..dfc617da 100644 --- a/modules/workflows/workflowManager.py +++ b/modules/workflows/workflowManager.py @@ -188,7 +188,6 @@ class WorkflowManager: detectedLanguage = None # No language detection in automation mode normalizedRequest = userInput.prompt intentText = userInput.prompt - contextItems = [] workflowIntent = None else: # Process user-uploaded documents from userInput for combined analysis @@ -206,7 +205,6 @@ class WorkflowManager: detectedLanguage = analysisResult.get('detectedLanguage') normalizedRequest = analysisResult.get('normalizedRequest') intentText = analysisResult.get('intent') or userInput.prompt - contextItems = analysisResult.get('contextItems', []) complexity = analysisResult.get('complexity', 'moderate') needsWorkflowHistory = analysisResult.get('needsWorkflowHistory', False) fastTrack = analysisResult.get('fastTrack', False) @@ -251,8 +249,6 @@ class WorkflowManager: # Fallback only if normalizedRequest is None or empty logger.warning(f"normalizedRequest is None or empty, falling back to intentText. normalizedRequest={normalizedRequest}, intentText={intentText[:100] if intentText else None}...") self.services.currentUserPromptNormalized = intentText or userInput.prompt - if contextItems is not None: - self.services.currentUserContextItems = contextItems # Set detected language if detectedLanguage and isinstance(detectedLanguage, str): @@ -305,7 +301,6 @@ class WorkflowManager: - detectedLanguage: ISO 639-1 Sprachcode - normalizedRequest: Vollständige, explizite Umformulierung - intent: Kurze Kern-Anfrage - - contextItems: Große Datenblöcke als separate Dokumente - complexity: "simple" | "moderate" | "complex" - needsWorkflowHistory: bool - fastTrack: bool @@ -323,24 +318,22 @@ class WorkflowManager: analysisPrompt = f"""You are an input analyzer. From the user's message, perform ALL of the following in one pass: 1. detectedLanguage: Detect ISO 639-1 language code (e.g., de, en, fr, it) -2. normalizedRequest: Full, explicit restatement of the user's request in the detected language; do NOT summarize; preserve ALL constraints and details +2. normalizedRequest: Full, explicit restatement of the user's request in the detected language; do NOT summarize; preserve ALL constraints and details. Include all data and context from the original message 3. intent: Concise single-paragraph core request in the detected language for high-level routing -4. contextItems: Supportive data blocks to attach as separate documents if significantly larger than the intent (large literal content, long lists/tables, code/JSON blocks, transcripts, CSV fragments, detailed specs). Keep URLs in the intent unless they embed large pasted content -5. complexity: "simple" | "moderate" | "complex" +4. complexity: "simple" | "moderate" | "complex" - "simple": Only if NO documents AND NO web search required. Single question, straightforward answer (5-15s) - "moderate": Multiple steps, some documents, structured response requiring some processing, or web search needed (30-60s) - "complex": Multi-task workflow, many documents, research needed, content generation required, multi-step planning (60-120s) -6. needsWorkflowHistory: Boolean indicating if this request needs previous workflow rounds/history (e.g., 'continue', 'retry', 'fix', 'improve', 'update', 'modify', 'based on previous', 'build on', references to earlier work) -7. fastTrack: Boolean indicating if Fast Track is possible (simple requests without documents and without workflow history) -8. dataType: What type of data/content they want (numbers|text|documents|analysis|code|unknown) -9. expectedFormats: What file format(s) they expect - provide matching file format extensions list (e.g., ["xlsx", "pdf"]). If format is unclear or not specified, use empty list [] -10. qualityRequirements: Quality requirements they have (accuracy, completeness) as {{accuracyThreshold: 0.0-1.0, completenessThreshold: 0.0-1.0}} -11. successCriteria: Specific success criteria that define completion (array of strings) -12. workflowName: Create a concise, descriptive name for this workflow in the detected language. The name should summarize the main task or goal (e.g., "Service Report January 2026", "Email Analysis", "Document Generation"). Keep it short (max 60 characters) and meaningful. +5. needsWorkflowHistory: Boolean indicating if this request needs previous workflow rounds/history (e.g., 'continue', 'retry', 'fix', 'improve', 'update', 'modify', 'based on previous', 'build on', references to earlier work) +6. fastTrack: Boolean indicating if Fast Track is possible (simple requests without documents and without workflow history) +7. dataType: What type of data/content they want (numbers|text|documents|analysis|code|unknown) +8. expectedFormats: What file format(s) they expect - provide matching file format extensions list (e.g., ["xlsx", "pdf"]). If format is unclear or not specified, use empty list [] +9. qualityRequirements: Quality requirements they have (accuracy, completeness) as {{accuracyThreshold: 0.0-1.0, completenessThreshold: 0.0-1.0}} +10. successCriteria: Specific success criteria that define completion (array of strings) +11. workflowName: Create a concise, descriptive name for this workflow in the detected language. The name should summarize the main task or goal (e.g., "Service Report January 2026", "Email Analysis", "Document Generation"). Keep it short (max 60 characters) and meaningful. Rules: -- If total content (intent + data) is < 10% of model max tokens, do not extract; return empty contextItems and keep intent compact and self-contained -- If content exceeds that threshold, move bulky parts into contextItems; keep intent short and clear +- normalizedRequest must contain the COMPLETE restatement including all data references - do NOT strip or extract content - Preserve critical references (URLs, filenames) in intent - Normalize to the primary detected language if mixed-language - Consider number of documents provided when determining complexity @@ -354,13 +347,6 @@ Return ONLY JSON (no markdown) with this exact structure: "detectedLanguage": "de|en|fr|it|...", "normalizedRequest": "Full explicit instruction in detected language", "intent": "Concise normalized request...", - "contextItems": [ - {{ - "title": "User context 1", - "mimeType": "text/plain", - "content": "Full extracted content block here" - }} - ], "complexity": "simple" | "moderate" | "complex", "needsWorkflowHistory": true|false, "fastTrack": true|false, @@ -375,7 +361,7 @@ Return ONLY JSON (no markdown) with this exact structure: }} ## User Message -The following is the user's original input message. Analyze intent, normalize the request, determine complexity, and identify any large context blocks that should be moved to separate documents: +The following is the user's original input message. Analyze intent, normalize the request, and determine complexity: ################ USER INPUT START ################# {userPrompt.replace('{', '{{').replace('}', '}}') if userPrompt else ''} @@ -410,7 +396,6 @@ The following is the user's original input message. Analyze intent, normalize th "detectedLanguage": "en", "normalizedRequest": "", "intent": "", - "contextItems": [], "complexity": "moderate", "needsWorkflowHistory": False, "fastTrack": False, @@ -450,7 +435,41 @@ The following is the user's original input message. Analyze intent, normalize th "taskProgress": "pending", "actionProgress": "pending" } - self.services.chat.storeMessageWithDocuments(workflow, firstMessageData, []) + + # Create user prompt original document + user-uploaded documents for "first" message + firstMessageDocs = [] + if userInput.prompt: + try: + originalPromptBytes = userInput.prompt.encode('utf-8') + originalPromptBytes = await self._neutralizeContentIfEnabled(originalPromptBytes, "text/markdown") + fileItem = self.services.interfaceDbComponent.createFile( + name="user_prompt_original.md", + mimeType="text/markdown", + content=originalPromptBytes + ) + self.services.interfaceDbComponent.createFileData(fileItem.id, originalPromptBytes) + fileInfo = self.services.chat.getFileInfo(fileItem.id) + doc = { + "fileId": fileItem.id, + "fileName": fileInfo.get("fileName", "user_prompt_original.md") if fileInfo else "user_prompt_original.md", + "fileSize": fileInfo.get("size", len(originalPromptBytes)) if fileInfo else len(originalPromptBytes), + "mimeType": fileInfo.get("mimeType", "text/markdown") if fileInfo else "text/markdown" + } + firstMessageDocs.append(doc) + logger.debug("Fast path: Stored original user prompt as document") + except Exception as e: + logger.warning(f"Fast path: Failed to store original prompt as document: {e}") + + # Process user-uploaded documents (fileIds) + if userInput.listFileId: + try: + userDocs = await self._processFileIds(userInput.listFileId, None) + if userDocs: + firstMessageDocs.extend(userDocs) + except Exception as e: + logger.warning(f"Fast path: Failed to process user fileIds: {e}") + + self.services.chat.storeMessageWithDocuments(workflow, firstMessageData, firstMessageDocs) # Get user language if available userLanguage = getattr(self.services, 'currentUserLanguage', None) @@ -587,7 +606,7 @@ The following is the user's original input message. Analyze intent, normalize th "actionProgress": "pending" } - # Analyze the user's input to detect language, normalize request, extract intent, and offload bulky context into documents + # Analyze the user's input to detect language, normalize request, and extract intent # SKIP user intention analysis if already done in combined analysis (skipIntentionAnalysis=True) # or for AUTOMATION mode - it uses predefined JSON plans createdDocs = [] @@ -600,61 +619,49 @@ The following is the user's original input message. Analyze intent, normalize th detectedLanguage = getattr(self.services, 'currentUserLanguage', None) normalizedRequest = getattr(self.services, 'currentUserPromptNormalized', None) or userInput.prompt intentText = getattr(self.services, 'currentUserPrompt', None) or userInput.prompt - contextItems = getattr(self.services, 'currentUserContextItems', None) or [] workflowIntent = getattr(workflow, '_workflowIntent', None) - # Create documents for context items (if available from combined analysis) - if contextItems and isinstance(contextItems, list): - for idx, item in enumerate(contextItems): - try: - title = item.get('title') if isinstance(item, dict) else None - mime = item.get('mimeType') if isinstance(item, dict) else None - content = item.get('content') if isinstance(item, dict) else None - if not content: - continue - fileName = (title or f"user_context_{idx+1}.txt").strip() - mimeType = (mime or "text/plain").strip() - - # Neutralize content before storing if neutralization is enabled - contentBytes = content.encode('utf-8') - contentBytes = await self._neutralizeContentIfEnabled(contentBytes, mimeType) - - # Create file in component storage - fileItem = self.services.interfaceDbComponent.createFile( - name=fileName, - mimeType=mimeType, - content=contentBytes - ) - # Persist file data - self.services.interfaceDbComponent.createFileData(fileItem.id, contentBytes) - - # Collect file info - fileInfo = self.services.chat.getFileInfo(fileItem.id) - doc = ChatDocument( - fileId=fileItem.id, - fileName=fileInfo.get("fileName", fileName) if fileInfo else fileName, - fileSize=fileInfo.get("size", len(contentBytes)) if fileInfo else len(contentBytes), - mimeType=fileInfo.get("mimeType", mimeType) if fileInfo else mimeType - ) - createdDocs.append(doc) - except Exception: - continue + # Use normalizedRequest as message, attach original prompt as document + if normalizedRequest and normalizedRequest != userInput.prompt: + messageData["message"] = normalizedRequest + logger.debug(f"Using normalized request as message (length: {len(normalizedRequest)})") + + # Store original user prompt as .md document + if userInput.prompt: + try: + originalPromptBytes = userInput.prompt.encode('utf-8') + originalPromptBytes = await self._neutralizeContentIfEnabled(originalPromptBytes, "text/markdown") + fileItem = self.services.interfaceDbComponent.createFile( + name="user_prompt_original.md", + mimeType="text/markdown", + content=originalPromptBytes + ) + self.services.interfaceDbComponent.createFileData(fileItem.id, originalPromptBytes) + fileInfo = self.services.chat.getFileInfo(fileItem.id) + doc = { + "fileId": fileItem.id, + "fileName": fileInfo.get("fileName", "user_prompt_original.md") if fileInfo else "user_prompt_original.md", + "fileSize": fileInfo.get("size", len(originalPromptBytes)) if fileInfo else len(originalPromptBytes), + "mimeType": fileInfo.get("mimeType", "text/markdown") if fileInfo else "text/markdown" + } + createdDocs.append(doc) + logger.debug("Stored original user prompt as document") + except Exception as e: + logger.warning(f"Failed to store original prompt as document: {e}") else: try: analyzerPrompt = ( "You are an input analyzer. From the user's message, perform ALL of the following in one pass:\n" "1) detectedLanguage: detect ISO 639-1 language code (e.g., de, en).\n" - "2) normalizedRequest: full, explicit restatement of the user's request in the detected language; do NOT summarize; preserve ALL constraints and details.\n" + "2) normalizedRequest: full, explicit restatement of the user's request in the detected language; do NOT summarize; preserve ALL constraints and details. Include all data and context from the original message.\n" "3) intent: concise single-paragraph core request in the detected language for high-level routing.\n" - "4) contextItems: supportive data blocks to attach as separate documents if significantly larger than the intent (large literal content, long lists/tables, code/JSON blocks, transcripts, CSV fragments, detailed specs). Keep URLs in the intent unless they embed large pasted content.\n" - "5) dataType: What type of data/content they want (numbers|text|documents|analysis|code|unknown).\n" - "6) expectedFormats: What file format(s) they expect - provide matching file format extensions list (e.g., [\"xlsx\", \"pdf\"]). If format is unclear or not specified, use empty list [].\n" - "7) qualityRequirements: Quality requirements they have (accuracy, completeness) as {accuracyThreshold: 0.0-1.0, completenessThreshold: 0.0-1.0}.\n" - "8) successCriteria: Specific success criteria that define completion (array of strings).\n" - "9) needsWorkflowHistory: Boolean indicating if this request needs previous workflow rounds/history to be understood or completed (e.g., 'continue', 'retry', 'fix', 'improve', 'update', 'modify', 'based on previous', 'build on', references to earlier work). Return true if the request is a continuation, retry, modification, or builds upon previous work.\n\n" + "4) dataType: What type of data/content they want (numbers|text|documents|analysis|code|unknown).\n" + "5) expectedFormats: What file format(s) they expect - provide matching file format extensions list (e.g., [\"xlsx\", \"pdf\"]). If format is unclear or not specified, use empty list [].\n" + "6) qualityRequirements: Quality requirements they have (accuracy, completeness) as {accuracyThreshold: 0.0-1.0, completenessThreshold: 0.0-1.0}.\n" + "7) successCriteria: Specific success criteria that define completion (array of strings).\n" + "8) needsWorkflowHistory: Boolean indicating if this request needs previous workflow rounds/history to be understood or completed (e.g., 'continue', 'retry', 'fix', 'improve', 'update', 'modify', 'based on previous', 'build on', references to earlier work). Return true if the request is a continuation, retry, modification, or builds upon previous work.\n\n" "Rules:\n" - "- If total content (intent + data) is < 10% of model max tokens, do not extract; return empty contextItems and keep intent compact and self-contained.\n" - "- If content exceeds that threshold, move bulky parts into contextItems; keep intent short and clear.\n" + "- normalizedRequest must contain the COMPLETE restatement including all data references - do NOT strip or extract content.\n" "- Preserve critical references (URLs, filenames) in intent.\n" "- Normalize to the primary detected language if mixed-language.\n\n" "Return ONLY JSON (no markdown) with this shape:\n" @@ -662,13 +669,6 @@ The following is the user's original input message. Analyze intent, normalize th " \"detectedLanguage\": \"de|en|fr|it|...\",\n" " \"normalizedRequest\": \"Full explicit instruction in detected language\",\n" " \"intent\": \"Concise normalized request...\",\n" - " \"contextItems\": [\n" - " {\n" - " \"title\": \"User context 1\",\n" - " \"mimeType\": \"text/plain\",\n" - " \"content\": \"Full extracted content block here\"\n" - " }\n" - " ],\n" " \"dataType\": \"numbers|text|documents|analysis|code|unknown\",\n" " \"expectedFormats\": [\"pdf\", \"docx\", \"xlsx\", \"txt\", \"json\", \"csv\", \"html\", \"md\"],\n" " \"qualityRequirements\": {\n" @@ -679,7 +679,7 @@ The following is the user's original input message. Analyze intent, normalize th " \"needsWorkflowHistory\": true|false\n" "}\n\n" "## User Message\n" - "The following is the user's original input message. Extract intent, normalize the request, and identify any large context blocks that should be moved to separate documents:\n\n" + "The following is the user's original input message. Analyze intent, normalize the request, and determine complexity:\n\n" "################ USER INPUT START #################\n" f"{userInput.prompt.replace('{', '{{').replace('}', '}}') if userInput.prompt else ''}\n" "################ USER INPUT FINISH #################" @@ -695,7 +695,6 @@ The following is the user's original input message. Analyze intent, normalize th detectedLanguage = None normalizedRequest = None intentText = userInput.prompt - contextItems = [] workflowIntent = None # Parse analyzer response (JSON expected) @@ -706,14 +705,11 @@ The following is the user's original input message. Analyze intent, normalize th parsed = json.loads(aiResponse[jsonStart:jsonEnd]) detectedLanguage = parsed.get('detectedLanguage') or None normalizedRequest = parsed.get('normalizedRequest') or None - if parsed.get('intent'): - intentText = parsed.get('intent') - contextItems = parsed.get('contextItems') or [] # Extract intent analysis fields and store as workflowIntent intentText = parsed.get('intent') or userInput.prompt workflowIntent = { - 'intent': intentText, # Use intent instead of primaryGoal + 'intent': intentText, 'dataType': parsed.get('dataType', 'unknown'), 'expectedFormats': parsed.get('expectedFormats', []), 'qualityRequirements': parsed.get('qualityRequirements', {}), @@ -724,32 +720,23 @@ The following is the user's original input message. Analyze intent, normalize th # Store needsWorkflowHistory in services for fast path decision needsHistoryFromIntention = parsed.get('needsWorkflowHistory', False) - # Always set the value - default to False if not a boolean setattr(self.services, '_needsWorkflowHistory', bool(needsHistoryFromIntention) if isinstance(needsHistoryFromIntention, bool) else False) # Store workflowIntent in workflow object for reuse if hasattr(self.services, 'workflow') and self.services.workflow: self.services.workflow._workflowIntent = workflowIntent except Exception: - contextItems = [] workflowIntent = None - # Ensure needsWorkflowHistory is False if parsing fails setattr(self.services, '_needsWorkflowHistory', False) - # Update services state - # CRITICAL: Validate language from AI response - # If AI didn't return language or invalid → use user language - # If user language not set → use "en" + # Validate language from AI response validatedLanguage = None - # Validate AI-detected language if detectedLanguage and isinstance(detectedLanguage, str): detectedLanguage = detectedLanguage.strip().lower() - # Check if it's a valid 2-character ISO code if len(detectedLanguage) == 2 and detectedLanguage.isalpha(): validatedLanguage = detectedLanguage - # If AI didn't return valid language, use user language if not validatedLanguage: userLanguage = getattr(self.services.user, 'language', None) if hasattr(self.services, 'user') and self.services.user else None if userLanguage and isinstance(userLanguage, str): @@ -757,12 +744,10 @@ The following is the user's original input message. Analyze intent, normalize th if len(userLanguage) == 2 and userLanguage.isalpha(): validatedLanguage = userLanguage - # Final fallback to "en" if not validatedLanguage: validatedLanguage = "en" logger.warning("Language not detected from AI and user language not set - using default 'en'") - # Set validated language self._setUserLanguage(validatedLanguage) try: setattr(self.services, 'currentUserLanguage', validatedLanguage) @@ -770,60 +755,40 @@ The following is the user's original input message. Analyze intent, normalize th except Exception: pass self.services.currentUserPrompt = intentText or userInput.prompt - # Always set currentUserPromptNormalized - use normalizedRequest if available, otherwise fallback to currentUserPrompt - # CRITICAL: normalizedRequest MUST be used if available, do NOT fall back to intent if normalizedRequest and normalizedRequest.strip(): - # Use normalizedRequest if available and not empty self.services.currentUserPromptNormalized = normalizedRequest logger.debug(f"Stored normalized request from analysis (length: {len(normalizedRequest)})") else: - # Fallback only if normalizedRequest is None or empty - logger.warning(f"normalizedRequest is None or empty in analysis, falling back to intentText. normalizedRequest={normalizedRequest}, intentText={intentText}") + logger.warning(f"normalizedRequest is None or empty in analysis, falling back to intentText") self.services.currentUserPromptNormalized = intentText or userInput.prompt - if contextItems is not None: - self.services.currentUserContextItems = contextItems - # Update message with normalized request if analysis produced one + # Use normalizedRequest as the chat message (transformed user input) if normalizedRequest and normalizedRequest != userInput.prompt: messageData["message"] = normalizedRequest logger.debug(f"Updated first message with normalized request (length: {len(normalizedRequest)})") - # Create documents for context items - if contextItems and isinstance(contextItems, list): - for idx, item in enumerate(contextItems): - try: - title = item.get('title') if isinstance(item, dict) else None - mime = item.get('mimeType') if isinstance(item, dict) else None - content = item.get('content') if isinstance(item, dict) else None - if not content: - continue - fileName = (title or f"user_context_{idx+1}.txt").strip() - mimeType = (mime or "text/plain").strip() - - # Neutralize content before storing if neutralization is enabled - contentBytes = content.encode('utf-8') - contentBytes = await self._neutralizeContentIfEnabled(contentBytes, mimeType) - - # Create file in component storage - fileItem = self.services.interfaceDbComponent.createFile( - name=fileName, - mimeType=mimeType, - content=contentBytes - ) - # Persist file data - self.services.interfaceDbComponent.createFileData(fileItem.id, contentBytes) - - # Collect file info - fileInfo = self.services.chat.getFileInfo(fileItem.id) - doc = ChatDocument( - fileId=fileItem.id, - fileName=fileInfo.get("fileName", fileName) if fileInfo else fileName, - fileSize=fileInfo.get("size", len(contentBytes)) if fileInfo else len(contentBytes), - mimeType=fileInfo.get("mimeType", mimeType) if fileInfo else mimeType - ) - createdDocs.append(doc) - except Exception: - continue + # Store original user prompt as .md document + if userInput.prompt: + try: + originalPromptBytes = userInput.prompt.encode('utf-8') + originalPromptBytes = await self._neutralizeContentIfEnabled(originalPromptBytes, "text/markdown") + fileItem = self.services.interfaceDbComponent.createFile( + name="user_prompt_original.md", + mimeType="text/markdown", + content=originalPromptBytes + ) + self.services.interfaceDbComponent.createFileData(fileItem.id, originalPromptBytes) + fileInfo = self.services.chat.getFileInfo(fileItem.id) + doc = { + "fileId": fileItem.id, + "fileName": fileInfo.get("fileName", "user_prompt_original.md") if fileInfo else "user_prompt_original.md", + "fileSize": fileInfo.get("size", len(originalPromptBytes)) if fileInfo else len(originalPromptBytes), + "mimeType": fileInfo.get("mimeType", "text/markdown") if fileInfo else "text/markdown" + } + createdDocs.append(doc) + logger.debug("Stored original user prompt as document") + except Exception as e: + logger.warning(f"Failed to store original prompt as document: {e}") except Exception as e: logger.warning(f"Prompt analysis failed or skipped: {str(e)}") From 7532841d9d00d509e2998c8cffac6bec3933e461 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Mon, 9 Feb 2026 12:57:47 +0100 Subject: [PATCH 16/18] fixed mandate flag --- modules/datamodels/datamodelUam.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/modules/datamodels/datamodelUam.py b/modules/datamodels/datamodelUam.py index d8e7906a..6b7cdc06 100644 --- a/modules/datamodels/datamodelUam.py +++ b/modules/datamodels/datamodelUam.py @@ -89,6 +89,14 @@ class Mandate(BaseModel): json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": True, "frontend_required": False} ) + @field_validator('isSystem', mode='before') + @classmethod + def _coerceIsSystem(cls, v): + """Coerce None to False (for existing DB records without isSystem field).""" + if v is None: + return False + return v + registerModelLabels( "Mandate", From d98c31a4d1bc3adec78328f8df1fa8fde23522dc Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Mon, 9 Feb 2026 23:44:52 +0100 Subject: [PATCH 17/18] logical fixes --- app.py | 4 +- modules/datamodels/datamodelFiles.py | 3 +- modules/datamodels/datamodelUtils.py | 13 +- .../automation/interfaceFeatureAutomation.py | 109 ++++--- .../automation/routeFeatureAutomation.py | 28 +- .../chatbot/interfaceFeatureChatbot.py | 4 +- modules/features/chatbot/service.py | 8 +- .../chatplayground/mainChatplayground.py | 8 +- .../realEstate/interfaceFeatureRealEstate.py | 4 +- .../trustee/interfaceFeatureTrustee.py | 10 +- modules/interfaces/interfaceBootstrap.py | 2 +- modules/interfaces/interfaceDbApp.py | 5 +- modules/interfaces/interfaceDbChat.py | 77 ++--- modules/interfaces/interfaceDbManagement.py | 296 +++++++++++------- modules/interfaces/interfaceRbac.py | 2 +- modules/routes/routeAdminAutomationEvents.py | 2 +- modules/routes/routeAdminRbacRules.py | 9 +- modules/routes/routeDataPrompts.py | 30 +- modules/security/rbac.py | 7 +- modules/services/__init__.py | 13 +- .../services/serviceAi/subStructureFilling.py | 4 +- .../mainServiceGeneration.py | 6 +- .../serviceGeneration/paths/codePath.py | 4 +- .../serviceGeneration/renderers/registry.py | 163 ++++++---- .../renderers/rendererCsv.py | 83 +++-- modules/shared/callbackRegistry.py | 15 +- modules/workflows/automation/mainWorkflow.py | 105 +++---- .../automation/subAutomationSchedule.py | 11 +- modules/workflows/methods/methodBase.py | 12 +- 29 files changed, 588 insertions(+), 449 deletions(-) diff --git a/app.py b/app.py index 32eb31f6..06ee8c2d 100644 --- a/app.py +++ b/app.py @@ -315,7 +315,7 @@ async def lifespan(app: FastAPI): logger.warning(f"Could not initialize feature containers: {e}") # --- Init Managers --- - await subAutomationSchedule.start(eventUser) # Automation scheduler + subAutomationSchedule.start(eventUser) # Automation scheduler eventManager.start() # Register audit log cleanup scheduler @@ -345,7 +345,7 @@ async def lifespan(app: FastAPI): # --- Stop Managers --- eventManager.stop() - await subAutomationSchedule.stop(eventUser) # Automation scheduler + subAutomationSchedule.stop(eventUser) # Automation scheduler # --- Stop Feature Containers (Plug&Play) --- try: diff --git a/modules/datamodels/datamodelFiles.py b/modules/datamodels/datamodelFiles.py index f1b07eb3..588097e4 100644 --- a/modules/datamodels/datamodelFiles.py +++ b/modules/datamodels/datamodelFiles.py @@ -3,7 +3,7 @@ """File-related datamodels: FileItem, FilePreview, FileData.""" from typing import Dict, Any, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from modules.shared.attributeUtils import registerModelLabels from modules.shared.timeUtils import getUtcTimestamp import uuid @@ -11,6 +11,7 @@ import base64 class FileItem(BaseModel): + model_config = ConfigDict(extra='allow') # Preserve system fields (_createdBy, _createdAt, etc.) id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) mandateId: Optional[str] = Field(default="", description="ID of the mandate this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) featureInstanceId: Optional[str] = Field(default="", description="ID of the feature instance this file belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) diff --git a/modules/datamodels/datamodelUtils.py b/modules/datamodels/datamodelUtils.py index 1ac9ad33..614d6592 100644 --- a/modules/datamodels/datamodelUtils.py +++ b/modules/datamodels/datamodelUtils.py @@ -3,22 +3,33 @@ """Utility datamodels: Prompt, TextMultilingual.""" from typing import Dict, Optional -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from modules.shared.attributeUtils import registerModelLabels import uuid class Prompt(BaseModel): + model_config = ConfigDict(extra='allow') # Preserve system fields (_createdBy, _createdAt, etc.) id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) mandateId: str = Field(default="", description="ID of the mandate this prompt belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) + isSystem: bool = Field(default=False, description="System prompt visible to all users (read-only for non-SysAdmin)", json_schema_extra={"frontend_type": "boolean", "frontend_readonly": True, "frontend_required": False}) content: str = Field(description="Content of the prompt", json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": True}) name: str = Field(description="Name of the prompt", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True}) + + @field_validator('isSystem', mode='before') + @classmethod + def _coerceIsSystem(cls, v): + """Existing records may have isSystem=None (field didn't exist). Treat None as False.""" + if v is None: + return False + return v registerModelLabels( "Prompt", {"en": "Prompt", "fr": "Invite"}, { "id": {"en": "ID", "fr": "ID"}, "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"}, + "isSystem": {"en": "System", "fr": "Système"}, "content": {"en": "Content", "fr": "Contenu"}, "name": {"en": "Name", "fr": "Nom"}, }, diff --git a/modules/features/automation/interfaceFeatureAutomation.py b/modules/features/automation/interfaceFeatureAutomation.py index e99f7683..3a7cba08 100644 --- a/modules/features/automation/interfaceFeatureAutomation.py +++ b/modules/features/automation/interfaceFeatureAutomation.py @@ -8,7 +8,6 @@ Uses the PostgreSQL connector for data access with user/mandate filtering. import logging import uuid import math -import asyncio from typing import Dict, Any, List, Optional, Union from modules.security.rbac import RbacClass @@ -99,7 +98,7 @@ class AutomationObjects: return True elif accessLevel == AccessLevel.MY: if recordId: - record = self.db.getRecordset(model, {"id": recordId}) + record = self.db.getRecordset(model, recordFilter={"id": recordId}) if record: return record[0].get("_createdBy") == self.userId else: @@ -118,16 +117,17 @@ class AutomationObjects: def _enrichAutomationsWithUserAndMandate(self, automations: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ - Batch enrich automations with user names and mandate names for display. + Batch enrich automations with user names, mandate names and feature instance labels. Uses direct DB lookup (no RBAC) because this is purely cosmetic enrichment — the user already has RBAC-verified access to the automations themselves. """ if not automations: return automations - # Collect all unique user IDs and mandate IDs + # Collect all unique IDs userIds = set() mandateIds = set() + featureInstanceIds = set() for automation in automations: createdBy = automation.get("_createdBy") @@ -137,48 +137,63 @@ class AutomationObjects: mandateId = automation.get("mandateId") if mandateId: mandateIds.add(mandateId) + + featureInstanceId = automation.get("featureInstanceId") + if featureInstanceId: + featureInstanceIds.add(featureInstanceId) # Use root DB connector for display-only lookups (no RBAC needed) + usersMap = {} + mandatesMap = {} + featureInstancesMap = {} try: from modules.datamodels.datamodelUam import UserInDB, Mandate + from modules.datamodels.datamodelFeatures import FeatureInstance from modules.security.rootAccess import getRootDbAppConnector dbAppConn = getRootDbAppConnector() # Batch fetch user display names - usersMap = {} if userIds: for userId in userIds: - users = dbAppConn.getRecordset(UserInDB, {"id": userId}) + users = dbAppConn.getRecordset(UserInDB, recordFilter={"id": userId}) if users: user = users[0] - fullName = f"{user.get('firstName', '')} {user.get('lastName', '')}".strip() - usersMap[userId] = fullName or user.get("email") or user.get("username") or userId + displayName = user.get("fullName") or user.get("username") or user.get("email") or None + if displayName: + usersMap[userId] = displayName # Batch fetch mandate display names - mandatesMap = {} if mandateIds: for mandateId in mandateIds: - mandates = dbAppConn.getRecordset(Mandate, {"id": mandateId}) + mandates = dbAppConn.getRecordset(Mandate, recordFilter={"id": mandateId}) if mandates: - mandatesMap[mandateId] = mandates[0].get("name") or mandateId + label = mandates[0].get("label") or mandates[0].get("name") or None + if label: + mandatesMap[mandateId] = label + + # Batch fetch feature instance labels + if featureInstanceIds: + for fiId in featureInstanceIds: + instances = dbAppConn.getRecordset(FeatureInstance, recordFilter={"id": fiId}) + if instances: + fi = instances[0] + label = fi.get("label") or fi.get("featureCode") or None + if label: + featureInstancesMap[fiId] = label except Exception as e: - logger.warning(f"Could not enrich automations with user/mandate names: {e}") - usersMap = {} - mandatesMap = {} + logger.warning(f"Could not enrich automations with display names: {e}") # Enrich each automation with the fetched data + # SECURITY: Never show a fallback name — if lookup fails, show empty string for automation in automations: createdBy = automation.get("_createdBy") - if createdBy: - automation["_createdByUserName"] = usersMap.get(createdBy, createdBy) - else: - automation["_createdByUserName"] = "-" + automation["_createdByUserName"] = usersMap.get(createdBy, "") if createdBy else "" mandateId = automation.get("mandateId") - if mandateId: - automation["mandateName"] = mandatesMap.get(mandateId, mandateId) - else: - automation["mandateName"] = "-" + automation["mandateName"] = mandatesMap.get(mandateId, "") if mandateId else "" + + featureInstanceId = automation.get("featureInstanceId") + automation["featureInstanceName"] = featureInstancesMap.get(featureInstanceId, "") if featureInstanceId else "" return automations @@ -195,11 +210,13 @@ class AutomationObjects: Supports optional pagination, sorting, and filtering. Computes status field for each automation. """ - # Use RBAC filtering + # AutomationDefinitions can belong to any feature instance within a mandate. + # Filter by mandateId only — not by featureInstanceId — to show all definitions across features. filteredAutomations = getRecordsetWithRBAC( self.db, AutomationDefinition, - self.currentUser + self.currentUser, + mandateId=self.mandateId ) # Compute status for each automation and normalize executionLogs @@ -282,12 +299,14 @@ class AutomationObjects: If False (default), returns Pydantic model without system fields. """ try: - # Use RBAC filtering + # AutomationDefinitions can belong to any feature instance within a mandate. + # Filter by mandateId only — not by featureInstanceId. filtered = getRecordsetWithRBAC( self.db, AutomationDefinition, self.currentUser, - recordFilter={"id": automationId} + recordFilter={"id": automationId}, + mandateId=self.mandateId ) if not filtered: @@ -363,8 +382,8 @@ class AutomationObjects: if createdAutomation.get("executionLogs") is None: createdAutomation["executionLogs"] = [] - # Trigger automation change callback (async, don't wait) - asyncio.create_task(self._notifyAutomationChanged()) + # Trigger automation change callback + self._notifyAutomationChanged() # Clean metadata fields and return Pydantic model cleanedRecord = {k: v for k, v in createdAutomation.items() if not k.startswith("_")} @@ -408,8 +427,8 @@ class AutomationObjects: if updatedAutomation.get("executionLogs") is None: updatedAutomation["executionLogs"] = [] - # Trigger automation change callback (async, don't wait) - asyncio.create_task(self._notifyAutomationChanged()) + # Trigger automation change callback + self._notifyAutomationChanged() # Clean metadata fields and return Pydantic model cleanedRecord = {k: v for k, v in updatedAutomation.items() if not k.startswith("_")} @@ -432,8 +451,8 @@ class AutomationObjects: # Delete automation from database self.db.recordDelete(AutomationDefinition, automationId) - # Trigger automation change callback (async, don't wait) - asyncio.create_task(self._notifyAutomationChanged()) + # Trigger automation change callback + self._notifyAutomationChanged() return True except Exception as e: @@ -454,7 +473,9 @@ class AutomationObjects: return getRecordsetWithRBAC( self.db, AutomationDefinition, - user + user, + mandateId=self.mandateId, + featureInstanceId=self.featureInstanceId ) # ========================================================================= @@ -466,7 +487,7 @@ class AutomationObjects: Returns automation templates filtered by RBAC (MY = own templates). Supports optional pagination, sorting, and filtering. """ - # Use RBAC filtering + # Templates are global (not mandate/feature-instance scoped) — no mandateId/featureInstanceId filter filteredTemplates = getRecordsetWithRBAC( self.db, AutomationTemplate, @@ -526,23 +547,24 @@ class AutomationObjects: userNameMap = {} for userId in userIds: - users = dbAppConn.getRecordset(UserInDB, {"id": userId}) + users = dbAppConn.getRecordset(UserInDB, recordFilter={"id": userId}) if users: user = users[0] - fullName = f"{user.get('firstName', '')} {user.get('lastName', '')}".strip() - userNameMap[userId] = fullName or user.get("email", "Unknown") + displayName = user.get("fullName") or user.get("username") or user.get("email") or None + if displayName: + userNameMap[userId] = displayName - # Apply to templates + # Apply to templates — SECURITY: no fallback, empty if not found for template in templates: createdBy = template.get("_createdBy") - if createdBy and createdBy in userNameMap: - template["_createdByUserName"] = userNameMap[createdBy] + template["_createdByUserName"] = userNameMap.get(createdBy, "") if createdBy else "" except Exception as e: logger.warning(f"Could not enrich templates with user names: {e}") def getAutomationTemplate(self, templateId: str) -> Optional[Dict[str, Any]]: """Returns an automation template by ID if user has access.""" try: + # Templates are global — no mandateId/featureInstanceId filter filtered = getRecordsetWithRBAC( self.db, AutomationTemplate, @@ -645,12 +667,13 @@ class AutomationObjects: logger.error(f"Error deleting automation template: {str(e)}") raise - async def _notifyAutomationChanged(self): - """Notify registered callbacks about automation changes (decoupled from features).""" + def _notifyAutomationChanged(self): + """Notify registered callbacks about automation changes (decoupled from features). + Sync-safe: works from both sync and async contexts.""" try: from modules.shared.callbackRegistry import callbackRegistry # Trigger callbacks without knowing which features are listening - await callbackRegistry.trigger('automation.changed', self) + callbackRegistry.trigger('automation.changed', self) except Exception as e: logger.error(f"Error notifying automation change: {str(e)}") diff --git a/modules/features/automation/routeFeatureAutomation.py b/modules/features/automation/routeFeatureAutomation.py index f7c5feda..d6845a3e 100644 --- a/modules/features/automation/routeFeatureAutomation.py +++ b/modules/features/automation/routeFeatureAutomation.py @@ -66,7 +66,9 @@ def get_automations( detail=f"Invalid pagination parameter: {str(e)}" ) - chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) + # AutomationDefinitions can belong to ANY feature instance within a mandate. + # The list endpoint must show all definitions for the user's mandate, not filter by a specific featureInstanceId. + chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None) result = chatInterface.getAllAutomationDefinitions(pagination=paginationParams) # If pagination was requested, result is PaginatedResult @@ -150,7 +152,7 @@ def get_available_actions( # Ensure methods are discovered (need a service center for discovery) if not methods: # Create a lightweight service center for method discovery - services = getServices(context.user, context.mandateId) + services = getServices(context.user, mandateId=context.mandateId) discoverMethods(services) actionsList = [] @@ -235,7 +237,7 @@ def get_automation( ) -> AutomationDefinition: """Get a single automation definition by ID""" try: - chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) + chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None) automation = chatInterface.getAutomationDefinition(automationId) if not automation: raise HTTPException( @@ -263,7 +265,7 @@ def update_automation( ) -> AutomationDefinition: """Update an automation definition""" try: - chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) + chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None) automationData = automation.model_dump() updated = chatInterface.updateAutomationDefinition(automationId, automationData) return updated @@ -291,7 +293,7 @@ def update_automation_status( ) -> AutomationDefinition: """Update only the active status of an automation definition""" try: - chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) + chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None) # Get existing automation automation = chatInterface.getAutomationDefinition(automationId) @@ -331,7 +333,7 @@ def delete_automation( ) -> Response: """Delete an automation definition""" try: - chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None, featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None) + chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None) success = chatInterface.deleteAutomationDefinition(automationId) if success: return Response(status_code=204) @@ -364,13 +366,15 @@ async def execute_automation_route( """Execute an automation immediately (test mode)""" try: from modules.services import getInterface as getServices - services = getServices(context.user, context.mandateId) - # Propagate feature context for billing - if context.featureInstanceId: - services.featureInstanceId = str(context.featureInstanceId) - services.featureCode = 'automation' + services = getServices(context.user, mandateId=context.mandateId, featureInstanceId=context.featureInstanceId) + + # Load automation with current user's context (user has RBAC permissions via UI) + automation = services.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True) + if not automation: + raise ValueError(f"Automation {automationId} not found") + from modules.workflows.automation import executeAutomation - workflow = await executeAutomation(automationId, services) + workflow = await executeAutomation(automationId, automation, context.user, services) return workflow except HTTPException: raise diff --git a/modules/features/chatbot/interfaceFeatureChatbot.py b/modules/features/chatbot/interfaceFeatureChatbot.py index 4d77f633..edfc0bc2 100644 --- a/modules/features/chatbot/interfaceFeatureChatbot.py +++ b/modules/features/chatbot/interfaceFeatureChatbot.py @@ -360,10 +360,12 @@ class ChatObjects: return False tableName = modelClass.__name__ + from modules.interfaces.interfaceRbac import buildDataObjectKey + objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None) permissions = self.rbac.getUserPermissions( self.currentUser, AccessRuleContext.DATA, - tableName, + objectKey, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId ) diff --git a/modules/features/chatbot/service.py b/modules/features/chatbot/service.py index 3afd1632..d6a8d1a4 100644 --- a/modules/features/chatbot/service.py +++ b/modules/features/chatbot/service.py @@ -91,12 +91,8 @@ async def chatProcess( ChatWorkflow instance """ try: - # Get services with mandate context - services = getServices(currentUser, mandateId) - - # Set feature context for billing - if featureInstanceId: - services.featureInstanceId = featureInstanceId + # Get services with mandate and feature instance context + services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId) services.featureCode = 'chatbot' interfaceDbChat = services.interfaceDbChat diff --git a/modules/features/chatplayground/mainChatplayground.py b/modules/features/chatplayground/mainChatplayground.py index 085d93e4..246236a1 100644 --- a/modules/features/chatplayground/mainChatplayground.py +++ b/modules/features/chatplayground/mainChatplayground.py @@ -49,10 +49,10 @@ RESOURCE_OBJECTS = [ ] # Template roles for this feature -# IMPORTANT: "viewer" role is required for automatic user assignment! +# Role names MUST follow convention: {featureCode}-{roleName} TEMPLATE_ROLES = [ { - "roleLabel": "viewer", + "roleLabel": "chatplayground-viewer", "description": { "en": "Chat Playground Viewer - View chat playground (read-only)", "de": "Chat Playground Betrachter - Chat Playground ansehen (nur lesen)", @@ -67,7 +67,7 @@ TEMPLATE_ROLES = [ ] }, { - "roleLabel": "user", + "roleLabel": "chatplayground-user", "description": { "en": "Chat Playground User - Use chat playground and workflows", "de": "Chat Playground Benutzer - Chat Playground und Workflows nutzen", @@ -86,7 +86,7 @@ TEMPLATE_ROLES = [ ] }, { - "roleLabel": "admin", + "roleLabel": "chatplayground-admin", "description": { "en": "Chat Playground Admin - Full access to chat playground", "de": "Chat Playground Admin - Vollzugriff auf Chat Playground", diff --git a/modules/features/realEstate/interfaceFeatureRealEstate.py b/modules/features/realEstate/interfaceFeatureRealEstate.py index 86374a2c..57e14f23 100644 --- a/modules/features/realEstate/interfaceFeatureRealEstate.py +++ b/modules/features/realEstate/interfaceFeatureRealEstate.py @@ -749,10 +749,12 @@ class RealEstateObjects: return False tableName = modelClass.__name__ + from modules.interfaces.interfaceRbac import buildDataObjectKey + objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None) permissions = self.rbac.getUserPermissions( self.currentUser, AccessRuleContext.DATA, - tableName, + objectKey, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId ) diff --git a/modules/features/trustee/interfaceFeatureTrustee.py b/modules/features/trustee/interfaceFeatureTrustee.py index 8710d148..7b3e4a6b 100644 --- a/modules/features/trustee/interfaceFeatureTrustee.py +++ b/modules/features/trustee/interfaceFeatureTrustee.py @@ -171,10 +171,12 @@ class TrusteeObjects: return False tableName = modelClass.__name__ + from modules.interfaces.interfaceRbac import buildDataObjectKey + objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None) permissions = self.rbac.getUserPermissions( self.currentUser, AccessRuleContext.DATA, - tableName, + objectKey, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId ) @@ -198,10 +200,12 @@ class TrusteeObjects: return AccessLevel.NONE tableName = modelClass.__name__ + from modules.interfaces.interfaceRbac import buildDataObjectKey + objectKey = buildDataObjectKey(tableName, featureCode=self.featureCode if hasattr(self, 'featureCode') else None) permissions = self.rbac.getUserPermissions( self.currentUser, AccessRuleContext.DATA, - tableName, + objectKey, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId ) @@ -1470,7 +1474,7 @@ class TrusteeObjects: def getAllUserAccess(self, userId: str) -> List[Dict[str, Any]]: """Get all access records for a user across all organisations.""" - return self.db.getRecordset(TrusteeAccess, {"userId": userId}) + return self.db.getRecordset(TrusteeAccess, recordFilter={"userId": userId}) def getUserTrusteeRoles(self, userId: str, organisationId: str, contractId: Optional[str] = None) -> List[str]: """ diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py index f750565d..8a58f352 100644 --- a/modules/interfaces/interfaceBootstrap.py +++ b/modules/interfaces/interfaceBootstrap.py @@ -129,7 +129,7 @@ def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str] # Get admin user ID if not provided (from poweron_app) if not adminUserId: - adminUsers = dbApp.getRecordset(UserInDB, {"email": APP_CONFIG.ADMIN_EMAIL}) + adminUsers = dbApp.getRecordset(UserInDB, recordFilter={"email": APP_CONFIG.ADMIN_EMAIL}) adminUserId = adminUsers[0]["id"] if adminUsers else None # Update context with admin user if adminUserId: diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py index 68fee415..2ca8232b 100644 --- a/modules/interfaces/interfaceDbApp.py +++ b/modules/interfaces/interfaceDbApp.py @@ -245,10 +245,13 @@ class AppObjects: return False tableName = modelClass.__name__ + # Use buildDataObjectKey for semantic namespace lookup + from modules.interfaces.interfaceRbac import buildDataObjectKey + objectKey = buildDataObjectKey(tableName) permissions = self.rbac.getUserPermissions( self.currentUser, AccessRuleContext.DATA, - tableName, + objectKey, mandateId=self.mandateId ) diff --git a/modules/interfaces/interfaceDbChat.py b/modules/interfaces/interfaceDbChat.py index e7925dbd..dfdefe3c 100644 --- a/modules/interfaces/interfaceDbChat.py +++ b/modules/interfaces/interfaceDbChat.py @@ -339,6 +339,18 @@ class ChatObjects: pass + def _getRecordset(self, modelClass, recordFilter=None, **kwargs): + """Wrapper for getRecordsetWithRBAC that automatically includes mandateId/featureInstanceId.""" + return getRecordsetWithRBAC( + self.db, + modelClass, + self.currentUser, + recordFilter=recordFilter, + mandateId=self.mandateId, + featureInstanceId=self.featureInstanceId, + **kwargs + ) + def checkRbacPermission( self, modelClass: type, @@ -610,12 +622,7 @@ class ChatObjects: If pagination is provided: PaginatedResult with items and metadata """ # Use RBAC filtering with featureInstanceId for instance-level isolation - filteredWorkflows = getRecordsetWithRBAC(self.db, - ChatWorkflow, - self.currentUser, - mandateId=self.mandateId, - featureInstanceId=self.featureInstanceId - ) + filteredWorkflows = self._getRecordset(ChatWorkflow) # If no pagination requested, return all items (no sorting - frontend handles it) if pagination is None: @@ -647,13 +654,7 @@ class ChatObjects: def getWorkflow(self, workflowId: str) -> Optional[ChatWorkflow]: """Returns a workflow by ID if user has access.""" # Use RBAC filtering with featureInstanceId for instance-level isolation - workflows = getRecordsetWithRBAC(self.db, - ChatWorkflow, - self.currentUser, - recordFilter={"id": workflowId}, - mandateId=self.mandateId, - featureInstanceId=self.featureInstanceId - ) + workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) if not workflows: return None @@ -809,7 +810,7 @@ class ChatObjects: # Delete message documents (but NOT the files!) # Note: ChatStat does NOT have messageId - stats are only at workflow level try: - existing_docs = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId}) + existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId}) for doc in existing_docs: self.db.recordDelete(ChatDocument, doc["id"]) except Exception as e: @@ -819,12 +820,12 @@ class ChatObjects: self.db.recordDelete(ChatMessage, messageId) # 2. Delete workflow stats - existing_stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"workflowId": workflowId}) + existing_stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId}) for stat in existing_stats: self.db.recordDelete(ChatStat, stat["id"]) # 3. Delete workflow logs - existing_logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId}) + existing_logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId}) for log in existing_logs: self.db.recordDelete(ChatLog, log["id"]) @@ -855,11 +856,7 @@ class ChatObjects: """ # Check workflow access first (without calling getWorkflow to avoid circular reference) # Use RBAC filtering - workflows = getRecordsetWithRBAC(self.db, - ChatWorkflow, - self.currentUser, - recordFilter={"id": workflowId} - ) + workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) if not workflows: if pagination is None: @@ -867,7 +864,7 @@ class ChatObjects: return PaginatedResult(items=[], totalItems=0, totalPages=0) # Get messages for this workflow from normalized table - messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId}) + messages = self._getRecordset(ChatMessage, recordFilter={"workflowId": workflowId}) # Convert raw messages to dict format for sorting/filtering messageDicts = [] @@ -1143,7 +1140,7 @@ class ChatObjects: raise ValueError("messageId cannot be empty") # Check if message exists in database - messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"id": messageId}) + messages = self._getRecordset(ChatMessage, recordFilter={"id": messageId}) if not messages: logger.warning(f"Message with ID {messageId} does not exist in database") @@ -1250,12 +1247,12 @@ class ChatObjects: # CASCADE DELETE: Delete all related data first # 1. Delete message stats - existing_stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"messageId": messageId}) + existing_stats = self._getRecordset(ChatStat, recordFilter={"messageId": messageId}) for stat in existing_stats: self.db.recordDelete(ChatStat, stat["id"]) # 2. Delete message documents (but NOT the files!) - existing_docs = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId}) + existing_docs = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId}) for doc in existing_docs: self.db.recordDelete(ChatDocument, doc["id"]) @@ -1282,7 +1279,7 @@ class ChatObjects: # Get documents for this message from normalized table - documents = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId}) + documents = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId}) if not documents: logger.warning(f"No documents found for message {messageId}") @@ -1323,7 +1320,7 @@ class ChatObjects: def getDocuments(self, messageId: str) -> List[ChatDocument]: """Returns documents for a message from normalized table.""" try: - documents = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId}) + documents = self._getRecordset(ChatDocument, recordFilter={"messageId": messageId}) return [ChatDocument(**doc) for doc in documents] except Exception as e: logger.error(f"Error getting message documents: {str(e)}") @@ -1369,11 +1366,7 @@ class ChatObjects: """ # Check workflow access first (without calling getWorkflow to avoid circular reference) # Use RBAC filtering - workflows = getRecordsetWithRBAC(self.db, - ChatWorkflow, - self.currentUser, - recordFilter={"id": workflowId} - ) + workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) if not workflows: if pagination is None: @@ -1381,7 +1374,7 @@ class ChatObjects: return PaginatedResult(items=[], totalItems=0, totalPages=0) # Get logs for this workflow from normalized table - logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId}) + logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId}) # Convert raw logs to dict format for sorting/filtering logDicts = [] @@ -1513,17 +1506,13 @@ class ChatObjects: """Returns list of statistics for a workflow if user has access.""" # Check workflow access first (without calling getWorkflow to avoid circular reference) # Use RBAC filtering - workflows = getRecordsetWithRBAC(self.db, - ChatWorkflow, - self.currentUser, - recordFilter={"id": workflowId} - ) + workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) if not workflows: return [] # Get stats for this workflow from normalized table - stats = getRecordsetWithRBAC(self.db, ChatStat, self.currentUser, recordFilter={"workflowId": workflowId}) + stats = self._getRecordset(ChatStat, recordFilter={"workflowId": workflowId}) if not stats: return [] @@ -1581,11 +1570,7 @@ class ChatObjects: """ # Check workflow access first # Use RBAC filtering - workflows = getRecordsetWithRBAC(self.db, - ChatWorkflow, - self.currentUser, - recordFilter={"id": workflowId} - ) + workflows = self._getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) if not workflows: return {"items": []} @@ -1594,7 +1579,7 @@ class ChatObjects: items = [] # Get messages - messages = getRecordsetWithRBAC(self.db, ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId}) + messages = self._getRecordset(ChatMessage, recordFilter={"workflowId": workflowId}) for msg in messages: # Apply timestamp filtering in Python msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp()) @@ -1635,7 +1620,7 @@ class ChatObjects: }) # Get logs - return all logs with roundNumber if available - logs = getRecordsetWithRBAC(self.db, ChatLog, self.currentUser, recordFilter={"workflowId": workflowId}) + logs = self._getRecordset(ChatLog, recordFilter={"workflowId": workflowId}) for log in logs: # Apply timestamp filtering in Python logTimestamp = parseTimestamp(log.get("timestamp"), default=getUtcTimestamp()) diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py index 59fc4c1c..b387b34c 100644 --- a/modules/interfaces/interfaceDbManagement.py +++ b/modules/interfaces/interfaceDbManagement.py @@ -313,10 +313,12 @@ class ComponentObjects: return False tableName = modelClass.__name__ + from modules.interfaces.interfaceRbac import buildDataObjectKey + objectKey = buildDataObjectKey(tableName) permissions = self.rbac.getUserPermissions( self.currentUser, AccessRuleContext.DATA, - tableName, + objectKey, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId ) @@ -590,10 +592,58 @@ class ComponentObjects: # Prompt methods + def _isSysAdmin(self) -> bool: + """Check if the current user is a SysAdmin.""" + return hasattr(self.currentUser, 'isSysAdmin') and self.currentUser.isSysAdmin + + def _enrichPromptsWithPermissions(self, prompts: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Enrich prompts with row-level _permissions based on ownership and isSystem flag. + + - SysAdmin: canUpdate=True, canDelete=True on all prompts + - Regular user on own prompts: canUpdate=True, canDelete=True + - Regular user on system prompts: canUpdate=False, canDelete=False (read-only) + """ + isSysAdmin = self._isSysAdmin() + for prompt in prompts: + isOwner = prompt.get("_createdBy") == self.userId + prompt["_permissions"] = { + "canUpdate": isOwner or isSysAdmin, + "canDelete": isOwner or isSysAdmin + } + return prompts + + def _getPromptsForUser(self) -> List[Dict[str, Any]]: + """Returns prompts visible to the current user. + + Visibility rules: + - SysAdmin: ALL prompts + - Regular user: own prompts (_createdBy) + system prompts (isSystem=True) + """ + if self._isSysAdmin(): + return self.db.getRecordset(Prompt) + + # Get own prompts + ownPrompts = self.db.getRecordset(Prompt, recordFilter={"_createdBy": self.userId}) + + # Get system prompts + systemPrompts = self.db.getRecordset(Prompt, recordFilter={"isSystem": True}) + + # Merge and deduplicate (a user's own prompt could also be isSystem) + seen = {} + for p in ownPrompts: + seen[p["id"]] = p + for p in systemPrompts: + if p["id"] not in seen: + seen[p["id"]] = p + + return list(seen.values()) + def getAllPrompts(self, pagination: Optional[PaginationParams] = None) -> Union[List[Prompt], PaginatedResult]: """ - Returns prompts based on user access level. - Supports optional pagination, sorting, and filtering. + Returns prompts with visibility rules: + - SysAdmin: sees ALL prompts, can CRUD all + - Regular user: sees own prompts + system prompts (isSystem=True), can only CRUD own + - Row-level _permissions control edit/delete buttons in the UI Args: pagination: Optional pagination parameters. If None, returns all items. @@ -603,11 +653,11 @@ class ComponentObjects: If pagination is provided: PaginatedResult with items and metadata """ try: - # Use RBAC filtering - filteredPrompts = getRecordsetWithRBAC(self.db, - Prompt, - self.currentUser - ) + # Get prompts based on user role (own + system for regular, all for SysAdmin) + filteredPrompts = self._getPromptsForUser() + + # Enrich with row-level permissions (_permissions: canUpdate, canDelete) + filteredPrompts = self._enrichPromptsWithPermissions(filteredPrompts) # If no pagination requested, return all items if pagination is None: @@ -630,7 +680,7 @@ class ComponentObjects: endIdx = startIdx + pagination.pageSize pagedPrompts = filteredPrompts[startIdx:endIdx] - # Convert to model objects + # Convert to model objects (extra='allow' on Prompt preserves system fields) items = [Prompt(**prompt) for prompt in pagedPrompts] return PaginatedResult( @@ -646,15 +696,24 @@ class ComponentObjects: return PaginatedResult(items=[], totalItems=0, totalPages=0) def getPrompt(self, promptId: str) -> Optional[Prompt]: - """Returns a prompt by ID if user has access.""" - # Use RBAC filtering - filteredPrompts = getRecordsetWithRBAC(self.db, - Prompt, - self.currentUser, - recordFilter={"id": promptId} - ) + """Returns a prompt by ID if the user has visibility. - return Prompt(**filteredPrompts[0]) if filteredPrompts else None + Visibility: SysAdmin sees all, regular user sees own + system prompts. + """ + filteredPrompts = self.db.getRecordset(Prompt, recordFilter={"id": promptId}) + if not filteredPrompts: + return None + + prompt = filteredPrompts[0] + + # Visibility check for non-SysAdmin: must be owner or system prompt + if not self._isSysAdmin(): + isOwner = prompt.get("_createdBy") == self.userId + isSystem = prompt.get("isSystem", False) + if not isOwner and not isSystem: + return None + + return Prompt(**prompt) def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]: """Creates a new prompt if user has permission.""" @@ -669,13 +728,25 @@ class ComponentObjects: return createdRecord def updatePrompt(self, promptId: str, updateData: Dict[str, Any]) -> Dict[str, Any]: - """Updates a prompt if user has access.""" + """Updates a prompt. Rules: + - SysAdmin: can update any prompt (including system prompts) + - Regular user: can only update own prompts (not system prompts) + """ try: - # Get prompt + # Get prompt (visibility-checked) prompt = self.getPrompt(promptId) if not prompt: raise ValueError(f"Prompt {promptId} not found") + # Permission check: owner or SysAdmin + isOwner = (getattr(prompt, '_createdBy', None) == self.userId) + if not self._isSysAdmin() and not isOwner: + raise PermissionError(f"No permission to update prompt {promptId}") + + # Regular users cannot set isSystem flag + if not self._isSysAdmin() and 'isSystem' in updateData: + del updateData['isSystem'] + # Update prompt record directly with the update data self.db.recordModify(Prompt, promptId, updateData) @@ -688,77 +759,69 @@ class ComponentObjects: return updatedPrompt.model_dump() + except PermissionError: + raise except Exception as e: logger.error(f"Error updating prompt: {str(e)}") raise ValueError(f"Failed to update prompt: {str(e)}") def deletePrompt(self, promptId: str) -> bool: - """Deletes a prompt if user has access.""" - # Check if the prompt exists and user has access + """Deletes a prompt. Rules: + - SysAdmin: can delete any prompt (including system prompts) + - Regular user: can only delete own prompts (not system prompts) + """ + # Get prompt (visibility-checked) prompt = self.getPrompt(promptId) if not prompt: return False - - if not self.checkRbacPermission(Prompt, "update", promptId): + + # Permission check: owner or SysAdmin + isOwner = (getattr(prompt, '_createdBy', None) == self.userId) + if not self._isSysAdmin() and not isOwner: raise PermissionError(f"No permission to delete prompt {promptId}") # Delete prompt success = self.db.recordDelete(Prompt, promptId) - return success # File Utilities - def checkForDuplicateFile(self, fileHash: str, fileName: str = None) -> Optional[FileItem]: - """Checks if a file with the same hash already exists for the current user and mandate. - If fileName is provided, also checks for exact name+hash match. - Only returns files the current user has access to.""" - # Get files with the hash, filtered by RBAC - accessibleFiles = getRecordsetWithRBAC(self.db, + def checkForDuplicateFile(self, fileHash: str, fileName: str) -> Optional[FileItem]: + """Checks if a file with the same hash AND fileName already exists for the current user. + + Duplicate = same user (_createdBy) + same fileHash + same fileName. + Same hash with different name is allowed (intentional copy by user). + Uses direct DB query (not RBAC) because files are isolated per user. + """ + if not self.userId: + return None + + # Direct DB query: find files with matching hash + name + user + matchingFiles = self.db.getRecordset( FileItem, - self.currentUser, - recordFilter={"fileHash": fileHash} + recordFilter={ + "_createdBy": self.userId, + "fileHash": fileHash, + "fileName": fileName + } ) - if not accessibleFiles: + if not matchingFiles: return None - - # If fileName is provided, check for exact name+hash match first - if fileName: - for file in accessibleFiles: - # Skip files without fileName key or with None/empty fileName - if "fileName" not in file or not file["fileName"]: - continue - if file["fileName"] == fileName: - return FileItem( - id=file["id"], - mandateId=file["mandateId"], - fileName=file["fileName"], - mimeType=file["mimeType"], - fileHash=file["fileHash"], - fileSize=file["fileSize"], - creationDate=file["creationDate"] - ) - # Return first valid file with matching hash (for general duplicate detection) - for file in accessibleFiles: - # Skip files without fileName key or with None/empty fileName - if "fileName" not in file or not file["fileName"]: - continue - # Use first valid file - return FileItem( - id=file["id"], - mandateId=file["mandateId"], - fileName=file["fileName"], - mimeType=file["mimeType"], - fileHash=file["fileHash"], - fileSize=file["fileSize"], - creationDate=file["creationDate"] - ) - - # If no valid files found, return None - return None + # Return first match + file = matchingFiles[0] + return FileItem( + id=file["id"], + mandateId=file.get("mandateId", ""), + featureInstanceId=file.get("featureInstanceId", ""), + fileName=file["fileName"], + mimeType=file["mimeType"], + fileHash=file["fileHash"], + fileSize=file["fileSize"], + creationDate=file["creationDate"] + ) def getMimeType(self, fileName: str) -> str: """Determines the MIME type based on the file extension.""" @@ -832,9 +895,18 @@ class ComponentObjects: # File methods - metadata-based operations + def _getFilesByCurrentUser(self, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]: + """Files are always user-scoped. Returns only files owned by the current user, + regardless of role (including SysAdmin). This bypasses RBAC intentionally.""" + filterDict = {"_createdBy": self.userId} + if recordFilter: + filterDict.update(recordFilter) + return self.db.getRecordset(FileItem, recordFilter=filterDict) + def getAllFiles(self, pagination: Optional[PaginationParams] = None) -> Union[List[FileItem], PaginatedResult]: """ - Returns files based on user access level. + Returns files owned by the current user (user-scoped, not RBAC-based). + Every user (including SysAdmin) only sees their own files. Supports optional pagination, sorting, and filtering. Args: @@ -844,13 +916,10 @@ class ComponentObjects: If pagination is None: List[FileItem] If pagination is provided: PaginatedResult with items and metadata """ - # Use RBAC filtering - filteredFiles = getRecordsetWithRBAC(self.db, - FileItem, - self.currentUser - ) + # Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override) + filteredFiles = self._getFilesByCurrentUser() - # Convert database records to FileItem instances (for both paginated and non-paginated) + # Convert database records to FileItem instances (extra='allow' preserves system fields like _createdBy) def convertFileItems(files): fileItems = [] for file in files: @@ -858,21 +927,14 @@ class ComponentObjects: # Ensure proper values, use defaults for invalid data creationDate = file.get("creationDate") if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0: - creationDate = getUtcTimestamp() + file["creationDate"] = getUtcTimestamp() fileName = file.get("fileName") if not fileName or fileName == "None": continue # Skip records with invalid fileName - fileItem = FileItem( - id=file.get("id"), - mandateId=file.get("mandateId"), - fileName=fileName, - mimeType=file.get("mimeType"), - fileHash=file.get("fileHash"), - fileSize=file.get("fileSize"), - creationDate=creationDate - ) + # Use **file to pass all fields including system fields (_createdBy, etc.) + fileItem = FileItem(**file) fileItems.append(fileItem) except Exception as e: logger.warning(f"Skipping invalid file record: {str(e)}") @@ -900,7 +962,7 @@ class ComponentObjects: endIdx = startIdx + pagination.pageSize pagedFiles = filteredFiles[startIdx:endIdx] - # Convert to model objects + # Convert to model objects (extra='allow' on FileItem preserves system fields) items = convertFileItems(pagedFiles) return PaginatedResult( @@ -910,13 +972,9 @@ class ComponentObjects: ) def getFile(self, fileId: str) -> Optional[FileItem]: - """Returns a file by ID if user has access.""" - # Use RBAC filtering - filteredFiles = getRecordsetWithRBAC(self.db, - FileItem, - self.currentUser, - recordFilter={"id": fileId} - ) + """Returns a file by ID if it belongs to the current user (user-scoped).""" + # Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override) + filteredFiles = self._getFilesByCurrentUser(recordFilter={"id": fileId}) if not filteredFiles: return None @@ -976,17 +1034,28 @@ class ComponentObjects: counter += 1 def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem: - """Creates a new file entry if user has permission. Computes fileHash and fileSize from content.""" + """Creates a new file entry if user has permission. Computes fileHash and fileSize from content. + + Duplicate check: if a file with the same user + fileHash + fileName already exists, + the existing file is returned instead of creating a new one. + Same hash with different name is allowed (intentional copy by user). + """ if not self.checkRbacPermission(FileItem, "create"): raise PermissionError("No permission to create files") - # Ensure fileName is unique - uniqueName = self._generateUniquefileName(name) - # Compute file size and hash fileSize = len(content) fileHash = hashlib.sha256(content).hexdigest() + # Duplicate check: same user + same hash + same fileName → return existing + existingFile = self.checkForDuplicateFile(fileHash, name) + if existingFile: + logger.info(f"Duplicate file detected in createFile: '{name}' (hash={fileHash[:12]}...) for user {self.userId} — returning existing file {existingFile.id}") + return existingFile + + # Ensure fileName is unique + uniqueName = self._generateUniquefileName(name) + # Use mandateId and featureInstanceId from context for proper data isolation # Convert None to empty string to satisfy Pydantic validation mandateId = self.mandateId or "" @@ -1005,7 +1074,6 @@ class ComponentObjects: # Store in database self.db.recordCreate(FileItem, fileItem) - return fileItem def updateFile(self, fileId: str, updateData: Dict[str, Any]) -> Dict[str, Any]: @@ -1040,20 +1108,16 @@ class ComponentObjects: if not self.checkRbacPermission(FileItem, "update", fileId): raise PermissionError(f"No permission to delete file {fileId}") - # Check for other references to this file (by hash) - use RBAC to only check files user has access to + # Check for other references to this file (by hash) - user-scoped check fileHash = file.fileHash if fileHash: - allReferences = getRecordsetWithRBAC(self.db, - FileItem, - self.currentUser, - recordFilter={"fileHash": fileHash} - ) + allReferences = self._getFilesByCurrentUser(recordFilter={"fileHash": fileHash}) otherReferences = [f for f in allReferences if f["id"] != fileId] # Only delete associated fileData if no other references exist if not otherReferences: try: - fileDataEntries = getRecordsetWithRBAC(self.db, FileData, self.currentUser, recordFilter={"id": fileId}) + fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId}) if fileDataEntries: self.db.recordDelete(FileData, fileId) logger.debug(f"FileData for file {fileId} deleted") @@ -1113,6 +1177,12 @@ class ComponentObjects: base64Encoded = True logger.debug(f"Stored file {fileId} as base64") + # Check if file data already exists (e.g., when createFile returned a duplicate) + existingData = self.db.getRecordset(FileData, recordFilter={"id": fileId}) + if existingData: + logger.debug(f"File data already exists for {fileId} — skipping duplicate storage") + return True + # Create the fileData record with data and encoding flag fileDataObj = { "id": fileId, @@ -1245,25 +1315,21 @@ class ComponentObjects: logger.error(f"Invalid fileContent type: {type(fileContent)}") raise ValueError(f"fileContent must be bytes, got {type(fileContent)}") - # Compute file hash first to check for duplicates + # Compute file hash to check for duplicates before any DB writes fileHash = hashlib.sha256(fileContent).hexdigest() - # Check for exact name+hash match first (same name + same content) + # Duplicate check: same user + same fileHash + same fileName → return existing file + # Same hash with different name is allowed (intentional copy by user) existingFile = self.checkForDuplicateFile(fileHash, fileName) if existingFile: - logger.info(f"Exact duplicate detected: {fileName} with same hash. Returning existing file reference.") + logger.info(f"Duplicate detected for user {self.userId}: '{fileName}' with hash {fileHash[:12]}... — returning existing file {existingFile.id}") return existingFile, "exact_duplicate" - # Check for hash-only match (same content, different name) - existingFileWithSameHash = self.checkForDuplicateFile(fileHash) - if existingFileWithSameHash: - logger.info(f"Content duplicate detected: {fileName} has same content as {existingFileWithSameHash.fileName}") - # Continue with upload - filename will be made unique if needed - # Determine MIME type mimeType = self.getMimeType(fileName) - # Save metadata and file (hash/size computed inside createFile) + # createFile handles its own duplicate check (for calls from other code paths) + # Here we already checked, so this will create a new file logger.debug(f"Saving file metadata to database for file: {fileName}") fileItem = self.createFile( name=fileName, diff --git a/modules/interfaces/interfaceRbac.py b/modules/interfaces/interfaceRbac.py index 21fd6fa2..313165a0 100644 --- a/modules/interfaces/interfaceRbac.py +++ b/modules/interfaces/interfaceRbac.py @@ -163,7 +163,7 @@ def getRecordsetWithRBAC( # Check view permission first if not permissions.view: - logger.debug(f"User {currentUser.id} has no view permission for {objectKey}") + logger.debug(f"User {currentUser.id} has no view permission for {objectKey} (mandateId={effectiveMandateId}, featureInstanceId={featureInstanceId})") return [] # Build WHERE clause with RBAC filtering diff --git a/modules/routes/routeAdminAutomationEvents.py b/modules/routes/routeAdminAutomationEvents.py index 7765d621..c89f1030 100644 --- a/modules/routes/routeAdminAutomationEvents.py +++ b/modules/routes/routeAdminAutomationEvents.py @@ -90,7 +90,7 @@ async def sync_all_automation_events( from modules.services import getInterface as getServices services = getServices(currentUser, None) - result = await syncAutomationEvents(services, eventUser) + result = syncAutomationEvents(services, eventUser) return { "success": True, "synced": result.get("synced", 0), diff --git a/modules/routes/routeAdminRbacRules.py b/modules/routes/routeAdminRbacRules.py index 1feb64a2..5a639431 100644 --- a/modules/routes/routeAdminRbacRules.py +++ b/modules/routes/routeAdminRbacRules.py @@ -78,11 +78,18 @@ def get_permissions( ) # MULTI-TENANT: Get permissions using context (mandateId/featureInstanceId) + # For DATA context, resolve short model names to full objectKeys + # e.g., "ChatWorkflow" → "data.chat.ChatWorkflow" + resolvedItem = item or "" + if accessContext == AccessRuleContext.DATA and resolvedItem and "." not in resolvedItem: + from modules.interfaces.interfaceRbac import buildDataObjectKey + resolvedItem = buildDataObjectKey(resolvedItem) + # Pass mandateId and featureInstanceId to load Feature-Instance roles permissions = interface.rbac.getUserPermissions( reqContext.user, accessContext, - item or "", + resolvedItem, mandateId=reqContext.mandateId, featureInstanceId=reqContext.featureInstanceId ) diff --git a/modules/routes/routeDataPrompts.py b/modules/routes/routeDataPrompts.py index 4aad221d..faf692fe 100644 --- a/modules/routes/routeDataPrompts.py +++ b/modules/routes/routeDataPrompts.py @@ -121,10 +121,10 @@ def get_prompt( def update_prompt( request: Request, promptId: str = Path(..., description="ID of the prompt to update"), - promptData: Prompt = Body(...), + promptData: Dict[str, Any] = Body(...), currentUser: User = Depends(getCurrentUser) ) -> Prompt: - """Update an existing prompt""" + """Update an existing prompt (supports partial updates for inline editing)""" managementInterface = interfaceDbManagement.getInterface(currentUser) # Check if the prompt exists @@ -135,14 +135,17 @@ def update_prompt( detail=f"Prompt with ID {promptId} not found" ) - # Convert Prompt to dict for interface, excluding the id field - if hasattr(promptData, "model_dump"): - update_data = promptData.model_dump(exclude={"id"}) - else: - update_data = promptData.model_dump(exclude={"id"}) + # Remove id from update data if present + update_data = {k: v for k, v in promptData.items() if k != "id"} - # Update prompt - updatedPrompt = managementInterface.updatePrompt(promptId, update_data) + # Update prompt (ownership check happens in interface) + try: + updatedPrompt = managementInterface.updatePrompt(promptId, update_data) + except PermissionError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) + ) if not updatedPrompt: raise HTTPException( @@ -170,7 +173,14 @@ def delete_prompt( detail=f"Prompt with ID {promptId} not found" ) - success = managementInterface.deletePrompt(promptId) + try: + success = managementInterface.deletePrompt(promptId) + except PermissionError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) + ) + if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/modules/security/rbac.py b/modules/security/rbac.py index 0a2136b1..55048e15 100644 --- a/modules/security/rbac.py +++ b/modules/security/rbac.py @@ -62,7 +62,7 @@ class RbacClass: Multi-Tenant Design: - Lädt Rollen aus UserMandate + UserMandateRole wenn mandateId gegeben - - isSysAdmin gibt vollen Zugriff auf System-Level (kein mandateId) + - isSysAdmin gibt vollen Zugriff, unabhängig vom Kontext Args: user: User object @@ -82,8 +82,8 @@ class RbacClass: delete=AccessLevel.NONE ) - # SysAdmin auf System-Level (kein Mandant) hat vollen Zugriff - if hasattr(user, 'isSysAdmin') and user.isSysAdmin and not mandateId: + # SysAdmin hat vollen Zugriff - unabhängig vom Kontext (Mandant/Feature) + if hasattr(user, 'isSysAdmin') and user.isSysAdmin: return UserPermissions( view=True, read=AccessLevel.ALL, @@ -96,6 +96,7 @@ class RbacClass: roleIds = self._getRoleIdsForUser(user, mandateId, featureInstanceId) if not roleIds: + logger.debug(f"getUserPermissions: NO roles found for user={user.id}, mandateId={mandateId}, featureInstanceId={featureInstanceId}, item={item}") return permissions # Lade alle relevanten Regeln für alle Rollen diff --git a/modules/services/__init__.py b/modules/services/__init__.py index 3e0a2560..6712372a 100644 --- a/modules/services/__init__.py +++ b/modules/services/__init__.py @@ -63,10 +63,11 @@ class Services: - Feature-specific Services are loaded dynamically via filename discovery """ - def __init__(self, user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None): + def __init__(self, user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None): self.user: User = user self.workflow = workflow self.mandateId: Optional[str] = mandateId + self.featureInstanceId: Optional[str] = featureInstanceId self.currentUserPrompt: str = "" self.rawUserPrompt: str = "" @@ -83,7 +84,7 @@ class Services: # CENTRAL INTERFACE (Chat/Workflow) # ============================================================ from modules.interfaces.interfaceDbChat import getInterface as getChatInterface - self.interfaceDbChat = getChatInterface(user, mandateId=mandateId) + self.interfaceDbChat = getChatInterface(user, mandateId=mandateId, featureInstanceId=featureInstanceId) # ============================================================ # SHARED SERVICES (from modules/services/) @@ -143,7 +144,7 @@ class Services: # Get interface via getInterface() if hasattr(module, "getInterface"): - interface = module.getInterface(self.user, mandateId=self.mandateId) + interface = module.getInterface(self.user, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId) # Derive attribute name: interfaceFeatureAiChat -> interfaceDbChat attrName = filename.replace("interfaceFeature", "interfaceDb") setattr(self, attrName, interface) @@ -191,6 +192,6 @@ class Services: logger.debug(f"Could not load service from {filepath}: {e}") -def getInterface(user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None) -> Services: - """Get Services instance for the given user and mandate context.""" - return Services(user, workflow, mandateId=mandateId) +def getInterface(user: User, workflow: "ChatWorkflow" = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> Services: + """Get Services instance for the given user, mandate, and feature instance context.""" + return Services(user, workflow, mandateId=mandateId, featureInstanceId=featureInstanceId) diff --git a/modules/services/serviceAi/subStructureFilling.py b/modules/services/serviceAi/subStructureFilling.py index 9b503567..a96b8353 100644 --- a/modules/services/serviceAi/subStructureFilling.py +++ b/modules/services/serviceAi/subStructureFilling.py @@ -2574,8 +2574,8 @@ CRITICAL: """ from modules.services.serviceGeneration.renderers.registry import getRenderer - # Get renderer for this format - NO FALLBACK - renderer = getRenderer(outputFormat, self.services) + # Get document renderer for this format (structure filling is document generation path) + renderer = getRenderer(outputFormat, self.services, outputStyle='document') if not renderer: raise ValueError(f"No renderer found for output format '{outputFormat}'. Check renderer registry.") diff --git a/modules/services/serviceGeneration/mainServiceGeneration.py b/modules/services/serviceGeneration/mainServiceGeneration.py index 4720c9a0..447b7f9d 100644 --- a/modules/services/serviceGeneration/mainServiceGeneration.py +++ b/modules/services/serviceGeneration/mainServiceGeneration.py @@ -556,10 +556,10 @@ class GenerationService: def _getFormatRenderer(self, output_format: str): - """Get the appropriate renderer for the specified format using auto-discovery.""" + """Get the appropriate document renderer for the specified format.""" try: from .renderers.registry import getRenderer, getSupportedFormats - renderer = getRenderer(output_format, services=self.services) + renderer = getRenderer(output_format, services=self.services, outputStyle='document') if renderer: return renderer @@ -573,7 +573,7 @@ class GenerationService: # Fallback to text renderer if no specific renderer found logger.warning(f"Falling back to text renderer for format {output_format}") - fallbackRenderer = getRenderer('text', services=self.services) + fallbackRenderer = getRenderer('text', services=self.services, outputStyle='document') if fallbackRenderer: return fallbackRenderer diff --git a/modules/services/serviceGeneration/paths/codePath.py b/modules/services/serviceGeneration/paths/codePath.py index f2470385..d43d275f 100644 --- a/modules/services/serviceGeneration/paths/codePath.py +++ b/modules/services/serviceGeneration/paths/codePath.py @@ -922,7 +922,7 @@ CRITICAL: """Get code renderer for file type.""" from modules.services.serviceGeneration.renderers.registry import getRenderer - # Map file types to renderer formats + # Map file types to renderer formats (code path) formatMap = { 'json': 'json', 'csv': 'csv', @@ -931,7 +931,7 @@ CRITICAL: rendererFormat = formatMap.get(fileType.lower()) if rendererFormat: - renderer = getRenderer(rendererFormat, self.services) + renderer = getRenderer(rendererFormat, self.services, outputStyle='code') # Check if renderer supports code rendering if renderer and hasattr(renderer, 'renderCodeFiles'): return renderer diff --git a/modules/services/serviceGeneration/renderers/registry.py b/modules/services/serviceGeneration/renderers/registry.py index c7e2d9f6..b0c96e80 100644 --- a/modules/services/serviceGeneration/renderers/registry.py +++ b/modules/services/serviceGeneration/renderers/registry.py @@ -2,20 +2,30 @@ # All rights reserved. """ Renderer registry for automatic discovery and registration of renderers. + +Renderers are indexed by (format, outputStyle) so that document generation +and code generation each get the correct renderer for the same format. """ import logging import importlib -from typing import Dict, Type, List, Optional +from typing import Dict, Type, List, Optional, Tuple from .documentRendererBaseTemplate import BaseRenderer logger = logging.getLogger(__name__) + class RendererRegistry: - """Registry for automatic renderer discovery and management.""" + """Registry for automatic renderer discovery and management. + + Maintains separate renderer mappings per outputStyle ('document', 'code', etc.) + so that document-generation and code-generation paths each resolve to the + correct renderer, even when both support the same format (e.g. 'csv'). + """ def __init__(self): - self._renderers: Dict[str, Type[BaseRenderer]] = {} + # Key: (formatName, outputStyle) -> rendererClass + self._renderers: Dict[Tuple[str, str], Type[BaseRenderer]] = {} self._format_mappings: Dict[str, str] = {} self._discovered = False @@ -25,39 +35,27 @@ class RendererRegistry: return try: - import os - import sys from pathlib import Path - # Get the directory containing this registry file currentDir = Path(__file__).parent - renderersDir = currentDir - - # Get the package name dynamically packageName = __name__.rsplit('.', 1)[0] - # Scan all Python files in the renderers directory - for filePath in renderersDir.glob("*.py"): - if filePath.name in ['registry.py', 'documentRendererBaseTemplate.py', '__init__.py']: + for filePath in currentDir.glob("*.py"): + if filePath.name in ['registry.py', 'documentRendererBaseTemplate.py', 'codeRendererBaseTemplate.py', '__init__.py']: continue - # Extract module name from filename moduleName = filePath.stem try: - # Import the module dynamically fullModuleName = f"{packageName}.{moduleName}" module = importlib.import_module(fullModuleName) - # Look for renderer classes in the module for attrName in dir(module): attr = getattr(module, attrName) if (isinstance(attr, type) and issubclass(attr, BaseRenderer) and attr != BaseRenderer and hasattr(attr, 'getSupportedFormats')): - - # Register the renderer self._registerRendererClass(attr) except Exception as e: @@ -68,60 +66,75 @@ class RendererRegistry: except Exception as e: logger.error(f"Error during renderer discovery: {str(e)}") - self._discovered = True # Mark as discovered to avoid repeated attempts + self._discovered = True def _registerRendererClass(self, rendererClass: Type[BaseRenderer]) -> None: - """Register a renderer class with its supported formats.""" + """Register a renderer class keyed by (format, outputStyle).""" try: - # Get supported formats from the renderer class supportedFormats = rendererClass.getSupportedFormats() - - # Get priority (default to 0 if not specified) + outputStyle = rendererClass.getOutputStyle() if hasattr(rendererClass, 'getOutputStyle') else 'document' priority = rendererClass.getPriority() if hasattr(rendererClass, 'getPriority') else 0 for formatName in supportedFormats: formatKey = formatName.lower() + registryKey = (formatKey, outputStyle) - # Check if format already registered - use priority to decide - if formatKey in self._renderers: - existingRenderer = self._renderers[formatKey] + if registryKey in self._renderers: + existingRenderer = self._renderers[registryKey] existingPriority = existingRenderer.getPriority() if hasattr(existingRenderer, 'getPriority') else 0 - # Only replace if new renderer has higher priority if priority > existingPriority: - logger.debug(f"Replacing {existingRenderer.__name__} with {rendererClass.__name__} for format '{formatName}' (priority {priority} > {existingPriority})") - self._renderers[formatKey] = rendererClass + logger.debug(f"Replacing {existingRenderer.__name__} with {rendererClass.__name__} for ({formatKey}, {outputStyle}) (priority {priority} > {existingPriority})") + self._renderers[registryKey] = rendererClass else: - logger.debug(f"Keeping {existingRenderer.__name__} for format '{formatName}' (priority {existingPriority} >= {priority})") + logger.debug(f"Keeping {existingRenderer.__name__} for ({formatKey}, {outputStyle}) (priority {existingPriority} >= {priority})") else: - # Register primary format - self._renderers[formatKey] = rendererClass + self._renderers[registryKey] = rendererClass - # Register aliases if any + # Register aliases if hasattr(rendererClass, 'getFormatAliases'): aliases = rendererClass.getFormatAliases() for alias in aliases: - self._format_mappings[alias.lower()] = formatName.lower() + self._format_mappings[alias.lower()] = formatKey - logger.debug(f"Registered {rendererClass.__name__} for formats: {supportedFormats} (priority: {priority})") + logger.debug(f"Registered {rendererClass.__name__} for formats={supportedFormats}, style={outputStyle}, priority={priority}") except Exception as e: logger.error(f"Error registering renderer {rendererClass.__name__}: {str(e)}") - def getRenderer(self, outputFormat: str, services=None) -> Optional[BaseRenderer]: - """Get a renderer instance for the specified format.""" + def getRenderer(self, outputFormat: str, services=None, outputStyle: str = None) -> Optional[BaseRenderer]: + """Get a renderer instance for the specified format and style. + + Args: + outputFormat: Format name (e.g. 'csv', 'json', 'pdf') + services: Services instance passed to renderer constructor + outputStyle: 'document' or 'code'. If None, returns the first match + with preference: document > code (most callers are document path). + """ if not self._discovered: self.discoverRenderers() - # Normalize format name formatName = outputFormat.lower().strip() - - # Check for aliases first if formatName in self._format_mappings: formatName = self._format_mappings[formatName] - # Get renderer class - rendererClass = self._renderers.get(formatName) + rendererClass = None + + if outputStyle: + # Exact match by style + rendererClass = self._renderers.get((formatName, outputStyle)) + else: + # No style specified — prefer 'document', then 'code', then any + for style in ['document', 'code']: + rendererClass = self._renderers.get((formatName, style)) + if rendererClass: + break + # Fallback: check any registered style + if not rendererClass: + for key, cls in self._renderers.items(): + if key[0] == formatName: + rendererClass = cls + break if rendererClass: try: @@ -130,7 +143,7 @@ class RendererRegistry: logger.error(f"Error creating renderer instance for {formatName}: {str(e)}") return None - logger.warning(f"No renderer found for format: {outputFormat}") + logger.warning(f"No renderer found for format={outputFormat}, style={outputStyle}") return None def getSupportedFormats(self) -> List[str]: @@ -138,9 +151,11 @@ class RendererRegistry: if not self._discovered: self.discoverRenderers() - formats = list(self._renderers.keys()) - formats.extend(self._format_mappings.keys()) - return sorted(set(formats)) + formats = set() + for (fmt, _style) in self._renderers.keys(): + formats.add(fmt) + formats.update(self._format_mappings.keys()) + return sorted(formats) def getRendererInfo(self) -> Dict[str, Dict[str, str]]: """Get information about all registered renderers.""" @@ -148,10 +163,12 @@ class RendererRegistry: self.discoverRenderers() info = {} - for formatName, rendererClass in self._renderers.items(): - info[formatName] = { + for (formatName, style), rendererClass in self._renderers.items(): + key = f"{formatName}:{style}" + info[key] = { 'class_name': rendererClass.__name__, 'module': rendererClass.__module__, + 'outputStyle': style, 'description': getattr(rendererClass, '__doc__', 'No description').strip().split('\n')[0] if rendererClass.__doc__ else 'No description' } @@ -160,44 +177,62 @@ class RendererRegistry: def getOutputStyle(self, outputFormat: str) -> Optional[str]: """ Get the output style classification for a given format. - Returns: 'code', 'document', 'image', or other (e.g., 'video' for future use) + When both 'document' and 'code' renderers exist for a format, + returns the default ('document') since this is called during document generation. """ if not self._discovered: self.discoverRenderers() - # Normalize format name formatName = outputFormat.lower().strip() - - # Check for aliases first if formatName in self._format_mappings: formatName = self._format_mappings[formatName] - # Get renderer class and call getOutputStyle (all renderers have same signature) - rendererClass = self._renderers.get(formatName) - try: - return rendererClass.getOutputStyle(formatName) - except (AttributeError, TypeError) as e: - logger.warning(f"No renderer found for format: {outputFormat}, cannot determine output style") - return None - except Exception as e: - logger.warning(f"Error getting output style for {outputFormat}: {str(e)}") - return None + # Check document first, then code + for style in ['document', 'code']: + rendererClass = self._renderers.get((formatName, style)) + if rendererClass: + try: + return rendererClass.getOutputStyle(formatName) + except Exception: + pass + + # Fallback: any style + for key, rendererClass in self._renderers.items(): + if key[0] == formatName: + try: + return rendererClass.getOutputStyle(formatName) + except Exception: + pass + + logger.warning(f"No renderer found for format: {outputFormat}, cannot determine output style") + return None + # Global registry instance _registry = RendererRegistry() -def getRenderer(outputFormat: str, services=None) -> Optional[BaseRenderer]: - """Get a renderer instance for the specified format.""" - return _registry.getRenderer(outputFormat, services) + +def getRenderer(outputFormat: str, services=None, outputStyle: str = None) -> Optional[BaseRenderer]: + """Get a renderer instance for the specified format and style. + + Args: + outputFormat: Format name (e.g. 'csv', 'json', 'pdf') + services: Services instance + outputStyle: 'document' or 'code'. If None, prefers document renderer. + """ + return _registry.getRenderer(outputFormat, services, outputStyle=outputStyle) + def getSupportedFormats() -> List[str]: """Get list of all supported formats.""" return _registry.getSupportedFormats() + def getRendererInfo() -> Dict[str, Dict[str, str]]: """Get information about all registered renderers.""" return _registry.getRendererInfo() + def getOutputStyle(outputFormat: str) -> Optional[str]: """Get the output style classification for a given format.""" return _registry.getOutputStyle(outputFormat) diff --git a/modules/services/serviceGeneration/renderers/rendererCsv.py b/modules/services/serviceGeneration/renderers/rendererCsv.py index 45871922..91312299 100644 --- a/modules/services/serviceGeneration/renderers/rendererCsv.py +++ b/modules/services/serviceGeneration/renderers/rendererCsv.py @@ -35,9 +35,9 @@ class RendererCsv(BaseRenderer): def getAcceptedSectionTypes(cls, formatName: Optional[str] = None) -> List[str]: """ Return list of section content types that CSV renderer accepts. - CSV renderer only accepts table sections. + CSV renderer accepts table sections and code_block sections (for raw CSV content). """ - return ["table"] + return ["table", "code_block"] async def render(self, extractedContent: Dict[str, Any], title: str, userPrompt: str = None, aiService=None) -> List[RenderedDocument]: """Render extracted JSON content to CSV format. Produces one CSV file per table section.""" @@ -62,16 +62,24 @@ class RendererCsv(BaseRenderer): if baseFilename.endswith('.csv'): baseFilename = baseFilename[:-4] - # Find all table sections + # Collect CSV-producing sections: table sections AND code_block sections with CSV language tableSections = [] + codeBlockCsvSections = [] for section in sections: sectionType = section.get("content_type", "paragraph") if sectionType == "table": tableSections.append(section) + elif sectionType == "code_block": + # Check if any element is a code_block with language "csv" + for element in section.get("elements", []): + content = element.get("content", {}) + if isinstance(content, dict) and content.get("language", "").lower() == "csv": + codeBlockCsvSections.append(section) + break - # If no table sections found, return empty CSV - if not tableSections: - self.logger.warning("No table sections found in CSV document - returning empty CSV") + # If no usable sections found, return empty CSV + if not tableSections and not codeBlockCsvSections: + self.logger.warning("No table or CSV code_block sections found in CSV document - returning empty CSV") emptyCsv = self._convertRowsToCsv([["No table data available"]]) return [ RenderedDocument( @@ -83,45 +91,52 @@ class RendererCsv(BaseRenderer): ) ] - # Generate one CSV file per table section + allCsvSections = tableSections + codeBlockCsvSections + + # Generate one CSV file per section renderedDocuments = [] - for i, tableSection in enumerate(tableSections): - # Generate CSV content for this table section - csvRows = [] + for i, csvSection in enumerate(allCsvSections): + sectionType = csvSection.get("content_type", "paragraph") + sectionTitle = csvSection.get("title") + csvContent = "" - # Add section title if available - sectionTitle = tableSection.get("title") - if sectionTitle: - csvRows.append([sectionTitle]) - csvRows.append([]) # Empty row after title + if sectionType == "code_block": + # Extract raw CSV content directly from code_block elements + rawCsvParts = [] + for element in csvSection.get("elements", []): + content = element.get("content", {}) + if isinstance(content, dict) and content.get("language", "").lower() == "csv": + code = content.get("code", "") + if code: + rawCsvParts.append(code) + csvContent = "\n".join(rawCsvParts) + else: + # Table section — render via table logic + csvRows = [] + if sectionTitle: + csvRows.append([sectionTitle]) + csvRows.append([]) # Empty row after title + + elements = csvSection.get("elements", []) + for element in elements: + tableRows = self._renderJsonTableToCsv(element) + if tableRows: + csvRows.extend(tableRows) + + csvContent = self._convertRowsToCsv(csvRows) - # Render table from section elements - elements = tableSection.get("elements", []) - for element in elements: - tableRows = self._renderJsonTableToCsv(element) - if tableRows: - csvRows.extend(tableRows) - - # Convert to CSV string - csvContent = self._convertRowsToCsv(csvRows) - - # Determine filename for this table - if len(tableSections) == 1: - # Single table - use base filename + # Determine filename + if len(allCsvSections) == 1: filename = f"{baseFilename}.csv" else: - # Multiple tables - add index or section title to filename - sectionId = tableSection.get("id", f"table_{i+1}") - # Use section title if available, otherwise use section ID + sectionId = csvSection.get("id", f"csv_{i+1}") if sectionTitle: - # Sanitize section title for filename safeTitle = "".join(c for c in sectionTitle if c.isalnum() or c in (' ', '-', '_')).strip() - safeTitle = safeTitle.replace(' ', '_')[:30] # Limit length + safeTitle = safeTitle.replace(' ', '_')[:30] filename = f"{baseFilename}_{safeTitle}.csv" else: filename = f"{baseFilename}_{sectionId}.csv" - # Extract document type from metadata documentType = metadata.get("documentType") if isinstance(metadata, dict) else None renderedDocuments.append( diff --git a/modules/shared/callbackRegistry.py b/modules/shared/callbackRegistry.py index 23eb84ab..361f4e1d 100644 --- a/modules/shared/callbackRegistry.py +++ b/modules/shared/callbackRegistry.py @@ -9,7 +9,6 @@ Features can register callbacks to be notified when automations change. import logging from typing import Callable, List, Dict, Any -import asyncio logger = logging.getLogger(__name__) @@ -25,7 +24,7 @@ class CallbackRegistry: Args: event_type: Type of event (e.g., 'automation.changed') - callback: Async or sync callback function + callback: Sync callback function """ if event_type not in self._callbacks: self._callbacks[event_type] = [] @@ -41,8 +40,8 @@ class CallbackRegistry: except ValueError: logger.warning(f"Callback not found for event type: {event_type}") - async def trigger(self, event_type: str, *args, **kwargs): - """Trigger all callbacks registered for an event type. + def trigger(self, event_type: str, *args, **kwargs): + """Trigger all registered callbacks for an event type. Args: event_type: Type of event to trigger @@ -55,18 +54,14 @@ class CallbackRegistry: for callback in callbacks: try: - if asyncio.iscoroutinefunction(callback): - await callback(*args, **kwargs) - else: - callback(*args, **kwargs) + callback(*args, **kwargs) except Exception as e: logger.error(f"Error executing callback for {event_type}: {str(e)}", exc_info=True) - def has_callbacks(self, event_type: str) -> bool: + def hasCallbacks(self, event_type: str) -> bool: """Check if there are any callbacks registered for an event type.""" return event_type in self._callbacks and len(self._callbacks[event_type]) > 0 # Global singleton instance callbackRegistry = CallbackRegistry() - diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py index 06d36dae..6e34c3cb 100644 --- a/modules/workflows/automation/mainWorkflow.py +++ b/modules/workflows/automation/mainWorkflow.py @@ -38,16 +38,14 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode featureCode: Feature code (e.g., 'chatplayground', 'automation') """ try: - services = getServices(currentUser, mandateId=mandateId) + services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId) # Store allowedProviders in services context for model selection if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders: services.allowedProviders = userInput.allowedProviders logger.info(f"AI provider filter active: {userInput.allowedProviders}") - # Store feature context in services (for billing and RBAC) - if featureInstanceId: - services.featureInstanceId = featureInstanceId + # Store feature code in services (for billing) if featureCode: services.featureCode = featureCode @@ -61,10 +59,8 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow: """Stops a running chat.""" try: - services = getServices(currentUser, mandateId=mandateId) - # Store feature instance ID in services context for proper RBAC filtering + services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId) if featureInstanceId: - services.featureInstanceId = featureInstanceId services.featureCode = 'chatplayground' workflowManager = WorkflowManager(services) return await workflowManager.workflowStop(workflowId) @@ -73,12 +69,17 @@ async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] raise -async def executeAutomation(automationId: str, services) -> ChatWorkflow: - """Execute automation workflow immediately (test mode) with placeholder replacement. +async def executeAutomation(automationId: str, automation, creatorUser: User, services) -> ChatWorkflow: + """Execute automation workflow with the creator user's context. + + The automation object and creatorUser are resolved by the caller (handler) + using the SysAdmin eventUser. This function does NOT re-load them. Args: automationId: ID of automation to execute - services: Services instance for data access + automation: Pre-loaded automation object (with system fields like _createdBy) + creatorUser: The user who created the automation (workflow runs in this context) + services: Services instance (used for interfaceDbApp etc.) Returns: ChatWorkflow instance created by automation execution @@ -92,11 +93,6 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: } try: - # 1. Load automation definition (with system fields for _createdBy access) - automation = services.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True) - if not automation: - raise ValueError(f"Automation {automationId} not found") - executionLog["messages"].append(f"Started execution at {executionStartTime}") # Store allowed providers from automation in services context @@ -105,12 +101,12 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}") # Context comes EXCLUSIVELY from the automation definition - services.mandateId = str(automation.mandateId) - services.featureInstanceId = str(automation.featureInstanceId) - services.featureCode = 'automation' - featureInstanceId = services.featureInstanceId + automationMandateId = str(automation.mandateId) + automationFeatureInstanceId = str(automation.featureInstanceId) - # 2. Replace placeholders in template to generate plan + logger.info(f"Executing automation {automationId} as user {creatorUser.id} with mandateId={automationMandateId}, featureInstanceId={automationFeatureInstanceId}") + + # 1. Replace placeholders in template to generate plan template = automation.template or "" placeholders = automation.placeholders or {} planJson = replacePlaceholders(template, placeholders) @@ -128,24 +124,9 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: logger.error(f"Context around error: ...{planJson[start:end]}...") raise ValueError(f"Invalid JSON after placeholder replacement: {str(e)}") executionLog["messages"].append("Template placeholders replaced successfully") + executionLog["messages"].append(f"Using creator user: {creatorUser.id}") - # 3. Get user who created automation - creatorUserId = getattr(automation, "_createdBy", None) - - # _createdBy is a system attribute - must be present - if not creatorUserId: - errorMsg = f"Automation {automationId} has no creator user (_createdBy field missing). Cannot execute automation." - logger.error(errorMsg) - executionLog["messages"].append(errorMsg) - raise ValueError(errorMsg) - - # Get creator user from database - creatorUser = services.interfaceDbApp.getUser(creatorUserId) - if not creatorUser: - raise ValueError(f"Creator user {creatorUserId} not found") - executionLog["messages"].append(f"Using creator user: {creatorUserId}") - - # 4. Create UserInputRequest from plan + # 2. Create UserInputRequest from plan # Embed plan JSON in prompt for TemplateMode to extract promptText = planToPrompt(plan) planJsonStr = json.dumps(plan) @@ -160,16 +141,15 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: executionLog["messages"].append("Starting workflow execution") - # 5. Start workflow using chatStart - # Pass mandateId, featureInstanceId, and featureCode from original services context - # so billing is recorded correctly with full feature context + # 3. Start workflow using chatStart with creator's context + # mandateId and featureInstanceId come from the automation definition workflow = await chatStart( currentUser=creatorUser, userInput=userInput, workflowMode=WorkflowModeEnum.WORKFLOW_AUTOMATION, workflowId=None, - mandateId=services.mandateId, - featureInstanceId=featureInstanceId, + mandateId=automationMandateId, + featureInstanceId=automationFeatureInstanceId, featureCode='automation' ) @@ -200,22 +180,22 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: executionLog["messages"].append(f"Error: {str(e)}") # Save execution log even on error (bypasses RBAC — system operation) + # Use the automation object already passed in (no re-load needed) try: - automation = services.interfaceDbAutomation.getAutomationDefinition(automationId) - if automation: - executionLogs = list(automation.executionLogs or []) - executionLogs.append(executionLog) - if len(executionLogs) > 50: - executionLogs = executionLogs[-50:] - services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs) + executionLogs = list(getattr(automation, 'executionLogs', None) or []) + executionLogs.append(executionLog) + if len(executionLogs) > 50: + executionLogs = executionLogs[-50:] + services.interfaceDbAutomation._saveExecutionLog(automationId, executionLogs) except Exception as logError: logger.error(f"Error saving execution log: {str(logError)}") raise -async def syncAutomationEvents(services, eventUser) -> Dict[str, Any]: - """Automation event handler - syncs scheduler with all active automations. +def syncAutomationEvents(services, eventUser) -> Dict[str, Any]: + """Sync scheduler with all active automations. + All operations (DB reads, scheduler registration) are synchronous. Args: services: Services instance for data access @@ -316,37 +296,28 @@ def createAutomationEventHandler(automationId: str, eventUser): logger.error("Event user not available for automation execution") return - # Get services for event user (provides access to interfaces) + # Load automation using SysAdmin eventUser (has unrestricted access) eventServices = getServices(eventUser, None) - - # Load automation using event user context (with system fields for _createdBy access) automation = eventServices.interfaceDbAutomation.getAutomationDefinition(automationId, includeSystemFields=True) if not automation or not getattr(automation, "active", False): logger.warning(f"Automation {automationId} not found or not active, skipping execution") return - # Get creator user + # Get creator user ID from automation's _createdBy system field creatorUserId = getattr(automation, "_createdBy", None) if not creatorUserId: - logger.error(f"Automation {automationId} has no creator user") + logger.error(f"Automation {automationId} has no creator user (_createdBy missing)") return - # Get mandate context from automation definition - automationMandateId = getattr(automation, "mandateId", None) - - # Get creator user from database using services - eventServices = getServices(eventUser, None) + # Get creator user from database (using SysAdmin access) creatorUser = eventServices.interfaceDbApp.getUser(creatorUserId) if not creatorUser: logger.error(f"Creator user {creatorUserId} not found for automation {automationId}") return - # Get services for creator user WITH mandate context from automation - creatorServices = getServices(creatorUser, automationMandateId) - - # Execute automation with creator user's context and mandate - # executeAutomation is in same module, so we can call it directly - await executeAutomation(automationId, creatorServices) + # Execute automation — pass automation object and creatorUser directly + # No re-load needed in executeAutomation + await executeAutomation(automationId, automation, creatorUser, eventServices) logger.info(f"Successfully executed automation {automationId} as user {creatorUserId}") except Exception as e: logger.error(f"Error executing automation {automationId}: {str(e)}") diff --git a/modules/workflows/automation/subAutomationSchedule.py b/modules/workflows/automation/subAutomationSchedule.py index 1061d65e..40638461 100644 --- a/modules/workflows/automation/subAutomationSchedule.py +++ b/modules/workflows/automation/subAutomationSchedule.py @@ -14,9 +14,10 @@ from modules.services import getInterface as getServices logger = logging.getLogger(__name__) -async def start(eventUser) -> None: +def start(eventUser) -> bool: """ Start automation scheduler and sync scheduled events. + All operations are synchronous (DB access, scheduler registration). Args: eventUser: System-level event user for background operations (provided by app.py) @@ -33,16 +34,16 @@ async def start(eventUser) -> None: services = getServices(eventUser, None) # Register callback for automation changes - async def onAutomationChanged(chatInterface): + def onAutomationChanged(chatInterface): """Callback triggered when automations are created/updated/deleted.""" eventServices = getServices(eventUser, None) - await syncAutomationEvents(eventServices, eventUser) + syncAutomationEvents(eventServices, eventUser) callbackRegistry.register('automation.changed', onAutomationChanged) logger.info("Automation: Registered change callback") # Initial sync on startup - await syncAutomationEvents(services, eventUser) + syncAutomationEvents(services, eventUser) logger.info("Automation: Scheduled events synced on startup") except Exception as e: @@ -52,7 +53,7 @@ async def start(eventUser) -> None: return True -async def stop(eventUser) -> None: +def stop(eventUser) -> bool: """ Stop automation scheduler. diff --git a/modules/workflows/methods/methodBase.py b/modules/workflows/methods/methodBase.py index 7934ea19..173023f1 100644 --- a/modules/workflows/methods/methodBase.py +++ b/modules/workflows/methods/methodBase.py @@ -139,11 +139,16 @@ class MethodBase: return False # RBAC-Check: RESOURCE context, item = actionId + # mandateId/featureInstanceId from services context needed to resolve user roles try: + mandateId = getattr(self.services, 'mandateId', None) + featureInstanceId = getattr(self.services, 'featureInstanceId', None) permissions = self.services.rbac.getUserPermissions( user=currentUser, context=AccessRuleContext.RESOURCE, - item=actionId + item=actionId, + mandateId=str(mandateId) if mandateId else None, + featureInstanceId=str(featureInstanceId) if featureInstanceId else None ) hasPermission = permissions.view if not hasPermission: @@ -151,8 +156,9 @@ class MethodBase: userRoles = getattr(currentUser, 'roleLabels', []) or [] self.logger.warning( f"RBAC denied action {actionId} for user {currentUser.id}. " - f"User roles: {userRoles}, " - f"Permissions: view={permissions.view}, edit={permissions.edit}, delete={permissions.delete}. " + f"User roles: {userRoles}, mandateId={mandateId}, " + f"Permissions: view={permissions.view}, read={permissions.read}, " + f"create={permissions.create}, update={permissions.update}, delete={permissions.delete}. " f"No matching RBAC rule found for context=RESOURCE, item={actionId}" ) return hasPermission From ab48e2e853ebaeebdf873e3d733ebc46cc551b93 Mon Sep 17 00:00:00 2001 From: patrick-motsch Date: Tue, 10 Feb 2026 00:10:07 +0100 Subject: [PATCH 18/18] enhanced generic navigation tree --- modules/datamodels/datamodelMembership.py | 2 +- modules/datamodels/datamodelRbac.py | 2 +- modules/datamodels/datamodelUam.py | 8 +++---- modules/interfaces/interfaceBootstrap.py | 29 +++++++++++++++++++++++ modules/interfaces/interfaceDbApp.py | 4 ++-- modules/routes/routeBilling.py | 2 +- modules/routes/routeDataMandates.py | 4 ++-- modules/routes/routeInvitations.py | 6 ++--- modules/routes/routeSystem.py | 2 +- 9 files changed, 44 insertions(+), 15 deletions(-) diff --git a/modules/datamodels/datamodelMembership.py b/modules/datamodels/datamodelMembership.py index e100b23c..5e8b8814 100644 --- a/modules/datamodels/datamodelMembership.py +++ b/modules/datamodels/datamodelMembership.py @@ -28,7 +28,7 @@ class UserMandate(BaseModel): ) mandateId: str = Field( description="FK → Mandate.id (CASCADE DELETE)", - json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_fk_source": "/api/mandates/", "frontend_fk_display_field": "name"} + json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_fk_source": "/api/mandates/", "frontend_fk_display_field": "label"} ) enabled: bool = Field( default=True, diff --git a/modules/datamodels/datamodelRbac.py b/modules/datamodels/datamodelRbac.py index 64ad56a4..978c3be6 100644 --- a/modules/datamodels/datamodelRbac.py +++ b/modules/datamodels/datamodelRbac.py @@ -55,7 +55,7 @@ class Role(BaseModel): mandateId: Optional[str] = Field( default=None, description="FK → Mandate.id (CASCADE DELETE). Null = Global/Template role.", - json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_visible": True, "frontend_required": False, "frontend_fk_source": "/api/mandates/", "frontend_fk_display_field": "name"} + json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_visible": True, "frontend_required": False, "frontend_fk_source": "/api/mandates/", "frontend_fk_display_field": "label"} ) featureInstanceId: Optional[str] = Field( default=None, diff --git a/modules/datamodels/datamodelUam.py b/modules/datamodels/datamodelUam.py index 6b7cdc06..22d94ebe 100644 --- a/modules/datamodels/datamodelUam.py +++ b/modules/datamodels/datamodelUam.py @@ -73,10 +73,10 @@ class Mandate(BaseModel): description="Name of the mandate", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True} ) - description: Optional[str] = Field( + label: Optional[str] = Field( default=None, - description="Description of the mandate", - json_schema_extra={"frontend_type": "textarea", "frontend_readonly": False, "frontend_required": False} + description="Display label of the mandate", + json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False} ) enabled: bool = Field( default=True, @@ -104,7 +104,7 @@ registerModelLabels( { "id": {"en": "ID", "de": "ID", "fr": "ID"}, "name": {"en": "Name", "de": "Name", "fr": "Nom"}, - "description": {"en": "Description", "de": "Beschreibung", "fr": "Description"}, + "label": {"en": "Label", "de": "Label", "fr": "Libellé"}, "enabled": {"en": "Enabled", "de": "Aktiviert", "fr": "Activé"}, "isSystem": {"en": "System Mandate", "de": "System-Mandant", "fr": "Mandat système"}, }, diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py index 8a58f352..3f678a2b 100644 --- a/modules/interfaces/interfaceBootstrap.py +++ b/modules/interfaces/interfaceBootstrap.py @@ -51,6 +51,9 @@ def initBootstrap(db: DatabaseConnector) -> None: # Initialize root mandate mandateId = initRootMandate(db) + # Migrate existing mandate records: description -> label + _migrateMandateDescriptionToLabel(db) + # Initialize system role TEMPLATES (mandateId=None, isSystemRole=True) initRoles(db) @@ -276,6 +279,32 @@ def initRootMandate(db: DatabaseConnector) -> Optional[str]: return mandateId +def _migrateMandateDescriptionToLabel(db: DatabaseConnector) -> None: + """ + Migration: Rename 'description' field to 'label' in all Mandate records. + Copies existing 'description' values to 'label' and removes the old field. + Safe to run multiple times (idempotent). + """ + allMandates = db.getRecordset(Mandate) + migratedCount = 0 + for mandateRecord in allMandates: + mandateId = mandateRecord.get("id") + hasDescription = "description" in mandateRecord and mandateRecord.get("description") is not None + hasLabel = "label" in mandateRecord and mandateRecord.get("label") is not None + + if hasDescription and not hasLabel: + # Copy description to label + updateData = {"label": mandateRecord["description"]} + db.recordModify(Mandate, mandateId, updateData) + migratedCount += 1 + logger.info(f"Migrated mandate {mandateId}: description -> label") + + if migratedCount > 0: + logger.info(f"Migrated {migratedCount} mandate(s) from description to label") + else: + logger.debug("No mandate description->label migration needed") + + def initAdminUser(db: DatabaseConnector, mandateId: Optional[str]) -> Optional[str]: """ Creates the Admin user if it doesn't exist. diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py index 2ca8232b..d6e4c6c0 100644 --- a/modules/interfaces/interfaceDbApp.py +++ b/modules/interfaces/interfaceDbApp.py @@ -1444,7 +1444,7 @@ class AppObjects: return Mandate(**filteredMandates[0]) - def createMandate(self, name: str, description: str = None, enabled: bool = True) -> Mandate: + def createMandate(self, name: str, label: str = None, enabled: bool = True) -> Mandate: """ Creates a new mandate if user has permission. Automatically copies system template roles (admin, user, viewer) to the new mandate. @@ -1453,7 +1453,7 @@ class AppObjects: raise PermissionError("No permission to create mandates") # Create mandate data using model - mandateData = Mandate(name=name, description=description, enabled=enabled) + mandateData = Mandate(name=name, label=label, enabled=enabled) # Create mandate record createdRecord = self.db.recordCreate(Mandate, mandateData) diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index e785ed13..586cc6dd 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -355,7 +355,7 @@ def getBalanceForMandate( from modules.interfaces.interfaceDbApp import getInterface as getAppInterface appInterface = getAppInterface(ctx.user, mandateId=targetMandateId) mandate = appInterface.getMandate(targetMandateId) - mandateName = mandate.get("name", "") if mandate else "" + mandateName = (mandate.get("label") or mandate.get("name", "")) if mandate else "" return BillingBalanceResponse( mandateId=targetMandateId, diff --git a/modules/routes/routeDataMandates.py b/modules/routes/routeDataMandates.py index 8d2c4a2b..b74b6277 100644 --- a/modules/routes/routeDataMandates.py +++ b/modules/routes/routeDataMandates.py @@ -192,7 +192,7 @@ def create_mandate( ) # Get optional fields with defaults - description = mandateData.get('description') + label = mandateData.get('label') enabled = mandateData.get('enabled', True) appInterface = interfaceDbApp.getRootInterface() @@ -200,7 +200,7 @@ def create_mandate( # Create mandate newMandate = appInterface.createMandate( name=name, - description=description, + label=label, enabled=enabled ) diff --git a/modules/routes/routeInvitations.py b/modules/routes/routeInvitations.py index 01c395e2..2454db44 100644 --- a/modules/routes/routeInvitations.py +++ b/modules/routes/routeInvitations.py @@ -190,7 +190,7 @@ def create_invitation( from modules.connectors.connectorMessagingEmail import ConnectorMessagingEmail # Get mandate name for the email mandate = rootInterface.getMandate(str(context.mandateId)) - mandateName = mandate.name if mandate else "PowerOn" + mandateName = (mandate.label or mandate.name) if mandate else "PowerOn" emailConnector = ConnectorMessagingEmail() emailSubject = f"Einladung zu {mandateName}" @@ -249,7 +249,7 @@ def create_invitation( # Get mandate name for notification mandate = rootInterface.getMandate(str(context.mandateId)) - mandateName = mandate.mandateLabel if mandate and mandate.mandateLabel else "PowerOn" + mandateName = (mandate.label or mandate.name) if mandate else "PowerOn" inviterName = context.user.fullName or context.user.username createInvitationNotification( @@ -529,7 +529,7 @@ def validate_invitation( # Get mandate name mandate = rootInterface.getMandate(str(mandateId)) if mandateId else None if mandate: - mandateName = mandate.name + mandateName = mandate.label or mandate.name # Get role names roleIds = invitation.roleIds or [] diff --git a/modules/routes/routeSystem.py b/modules/routes/routeSystem.py index 3c8cdd3d..83f4ed85 100644 --- a/modules/routes/routeSystem.py +++ b/modules/routes/routeSystem.py @@ -153,7 +153,7 @@ def _buildDynamicBlock( mandateId = str(instance.mandateId) if mandateId not in mandatesMap: mandate = rootInterface.getMandate(mandateId) - mandateName = mandate.name if mandate and hasattr(mandate, 'name') else mandateId + mandateName = (mandate.label or mandate.name) if mandate else mandateId mandatesMap[mandateId] = { "id": mandateId, "uiLabel": mandateName,