diff --git a/app.py b/app.py index 9aa05093..0acdcfce 100644 --- a/app.py +++ b/app.py @@ -312,6 +312,39 @@ async def lifespan(app: FastAPI): # Register audit log cleanup scheduler from modules.shared.auditLogger import registerAuditLogCleanupScheduler registerAuditLogCleanupScheduler() + + # Ensure billing settings and accounts exist + try: + from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface + from modules.datamodels.datamodelBilling import BillingSettings, BillingModelEnum + + billingInterface = getBillingRootInterface() + + # Ensure root mandate has billing settings + rootMandate = rootInterface.getRootMandate() + if rootMandate: + rootMandateId = rootMandate.get("id") if isinstance(rootMandate, dict) else getattr(rootMandate, "id", None) + if rootMandateId: + existingSettings = billingInterface.getSettings(rootMandateId) + if not existingSettings: + settings = BillingSettings( + mandateId=rootMandateId, + billingModel=BillingModelEnum.PREPAY_USER, + defaultUserCredit=10.0, + warningThresholdPercent=10.0, + blockOnZeroBalance=True, + notifyOnWarning=True + ) + billingInterface.createSettings(settings) + logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit") + + # Efficient bulk check: Ensure all users have billing accounts (3 queries total) + accountsCreated = billingInterface.ensureAllUserAccountsExist() + if accountsCreated > 0: + logger.info(f"Billing startup: Created {accountsCreated} missing user accounts") + + except Exception as e: + logger.warning(f"Failed to ensure billing settings/accounts (non-critical): {e}") yield diff --git a/modules/aicore/aicorePluginPrivateLlm.py b/modules/aicore/aicorePluginPrivateLlm.py index 84a5a6b4..dfcb7eaf 100644 --- a/modules/aicore/aicorePluginPrivateLlm.py +++ b/modules/aicore/aicorePluginPrivateLlm.py @@ -216,12 +216,14 @@ class AiPrivateLlm(BaseConnectorAi): availableOllamaModels = availability.get("availableModels", []) # Define all models with their Ollama backend names - # Actual model specs (for 32GB RAM server): - # - qwen2.5:7b: 7.6B params, 128K context, ~4.7GB RAM (Text) - # - qwen2.5vl:7b: 8.29B params, 125K context, ~6GB RAM (Vision) - # - granite3.2-vision: 2B params, 16K context, ~2.4GB RAM (Vision) + # Actual model specs (for 31GB RAM + 22GB GPU server): + # Context sizes reduced to fit in available RAM + # - qwen2.5:7b: 7.6B params, ~4.7GB RAM (Text) - 8K context + # - qwen2.5vl:7b: 8.29B params, ~6GB RAM (Vision) - 4K context + # - granite3.2-vision: 2B params, ~2.4GB RAM (Vision) - 4K context + # - deepseek-ocr: ~6.7GB RAM (OCR) - 4K context modelDefinitions = [ - # Text Model (qwen2.5:7b: 7.6B, 128K context) + # Text Model (qwen2.5:7b: 7.6B) { "model": AiModel( name="poweron-text-general", @@ -229,8 +231,8 @@ class AiPrivateLlm(BaseConnectorAi): connectorType="privatellm", apiUrl=f"{self.baseUrl}/api/analyze", temperature=0.1, - maxTokens=8192, - contextLength=128000, # qwen2.5:7b actual context: 128K + maxTokens=4096, + contextLength=8192, # Reduced for RAM constraints costPer1kTokensInput=0.0, # Flat rate pricing costPer1kTokensOutput=0.0, # Flat rate pricing speedRating=8, # Fast and efficient @@ -247,7 +249,7 @@ class AiPrivateLlm(BaseConnectorAi): ), "ollamaModel": "qwen2.5:7b" }, - # Vision General Model (qwen2.5vl:7b: 8.29B, 125K context) + # Vision General Model (qwen2.5vl:7b: 8.29B) { "model": AiModel( name="poweron-vision-general", @@ -255,8 +257,8 @@ class AiPrivateLlm(BaseConnectorAi): connectorType="privatellm", apiUrl=f"{self.baseUrl}/api/analyze", temperature=0.2, - maxTokens=8192, - contextLength=125000, # qwen2.5vl:7b actual context: 125K + maxTokens=2048, + contextLength=4096, # Reduced for RAM constraints (vision needs more) costPer1kTokensInput=0.0, # Flat rate pricing costPer1kTokensOutput=0.0, # Flat rate pricing speedRating=7, @@ -273,7 +275,7 @@ class AiPrivateLlm(BaseConnectorAi): ), "ollamaModel": "qwen2.5vl:7b" }, - # Vision Deep Model (granite3.2-vision: 2B, 16K context) + # Vision Deep Model (granite3.2-vision: 2B) { "model": AiModel( name="poweron-vision-deep", @@ -281,8 +283,8 @@ class AiPrivateLlm(BaseConnectorAi): connectorType="privatellm", apiUrl=f"{self.baseUrl}/api/analyze", temperature=0.1, - maxTokens=4096, - contextLength=16000, # granite3.2-vision actual context: 16K + maxTokens=2048, + contextLength=4096, # Reduced for RAM constraints costPer1kTokensInput=0.0, # Flat rate pricing costPer1kTokensOutput=0.0, # Flat rate pricing speedRating=9, # Fast due to small 2B model diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index 02f80762..b1e73ae0 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -301,6 +301,7 @@ registerModelLabels( class ChatWorkflow(BaseModel): """Chat workflow container. User-owned, no mandate context.""" id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) + featureInstanceId: Optional[str] = Field(None, description="Feature instance ID for multi-tenancy isolation", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) status: str = Field(default="running", description="Current status of the workflow", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [ {"value": "running", "label": {"en": "Running", "fr": "En cours"}}, {"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}}, @@ -374,6 +375,7 @@ registerModelLabels( {"en": "Chat Workflow", "fr": "Flux de travail de chat"}, { "id": {"en": "ID", "fr": "ID"}, + "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"}, "status": {"en": "Status", "fr": "Statut"}, "name": {"en": "Name", "fr": "Nom"}, "currentRound": {"en": "Current Round", "fr": "Tour actuel"}, @@ -399,7 +401,8 @@ class UserInputRequest(BaseModel): listFileId: List[str] = Field(default_factory=list, description="List of file IDs") userLanguage: str = Field(default="en", description="User's preferred language") workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue") - preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai')") + preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai') - deprecated, use preferredProviders") + preferredProviders: Optional[List[str]] = Field(None, description="List of preferred AI providers (multiselect)") registerModelLabels( diff --git a/modules/features/chatplayground/routeFeatureChatplayground.py b/modules/features/chatplayground/routeFeatureChatplayground.py index 6a76e70e..cedd05d6 100644 --- a/modules/features/chatplayground/routeFeatureChatplayground.py +++ b/modules/features/chatplayground/routeFeatureChatplayground.py @@ -131,7 +131,7 @@ async def stop_workflow( # Validate access and get mandate ID mandateId = await _validateInstanceAccess(instanceId, context) - # Stop workflow + # Stop workflow (pass featureInstanceId for proper RBAC filtering) workflow = await chatStop( context.user, workflowId, diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py index d6f0f063..99cca1a2 100644 --- a/modules/interfaces/interfaceBootstrap.py +++ b/modules/interfaces/interfaceBootstrap.py @@ -1273,34 +1273,57 @@ def initRootMandateBilling(mandateId: str) -> None: """ Initialize billing settings for root mandate. Root mandate uses PREPAY_USER model with 10 CHF initial credit per user. + Also creates billing accounts for all users of the mandate. Args: mandateId: Root mandate ID """ try: from modules.interfaces.interfaceDbBilling import _getRootInterface + from modules.interfaces.interfaceDbApp import getRootInterface as getAppRootInterface from modules.datamodels.datamodelBilling import BillingSettings, BillingModelEnum billingInterface = _getRootInterface() + appInterface = getAppRootInterface() # Check if settings already exist existingSettings = billingInterface.getSettings(mandateId) if existingSettings: logger.info("Billing settings for root mandate already exist") - return + else: + # Create billing settings for root mandate + settings = BillingSettings( + mandateId=mandateId, + billingModel=BillingModelEnum.PREPAY_USER, + defaultUserCredit=10.0, # 10 CHF initial credit per user + warningThresholdPercent=10.0, + blockOnZeroBalance=True, + notifyOnWarning=True + ) + + billingInterface.createSettings(settings) + logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit") + existingSettings = billingInterface.getSettings(mandateId) - # Create billing settings for root mandate - settings = BillingSettings( - mandateId=mandateId, - billingModel=BillingModelEnum.PREPAY_USER, - defaultUserCredit=10.0, # 10 CHF initial credit per user - warningThresholdPercent=10.0, - blockOnZeroBalance=True, - notifyOnWarning=True - ) - - billingInterface.createSettings(settings) - logger.info(f"Created billing settings for root mandate: PREPAY_USER with 10 CHF default credit") + # Create billing accounts for all users of the mandate + if existingSettings: + billingModel = existingSettings.get("billingModel", "UNLIMITED") + if billingModel == BillingModelEnum.PREPAY_USER.value: + defaultCredit = existingSettings.get("defaultUserCredit", 10.0) + userMandates = appInterface.getUserMandatesByMandate(mandateId) + accountsCreated = 0 + + for um in userMandates: + userId = um.get("userId") if isinstance(um, dict) else getattr(um, "userId", None) + if userId: + existingAccount = billingInterface.getUserAccount(mandateId, userId) + if not existingAccount: + billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit) + accountsCreated += 1 + logger.debug(f"Created billing account for user {userId}") + + if accountsCreated > 0: + logger.info(f"Created {accountsCreated} billing accounts for root mandate users with {defaultCredit} CHF each") except Exception as e: # Don't fail bootstrap if billing init fails diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py index 2a872bce..1c082d33 100644 --- a/modules/interfaces/interfaceDbApp.py +++ b/modules/interfaces/interfaceDbApp.py @@ -1608,6 +1608,7 @@ class AppObjects: def createUserMandate(self, userId: str, mandateId: str, roleIds: List[str] = None) -> UserMandate: """ Create a UserMandate record (add user to mandate). + Also creates a billing account for the user if billing is configured for PREPAY_USER. Args: userId: User ID @@ -1641,11 +1642,45 @@ class AppObjects: ) self.db.recordCreate(UserMandateRole, userMandateRole.model_dump()) + # Create billing account for user if billing is configured + self._ensureUserBillingAccount(userId, mandateId) + cleanedRecord = {k: v for k, v in createdRecord.items() if not k.startswith("_")} return UserMandate(**cleanedRecord) except Exception as e: logger.error(f"Error creating UserMandate: {e}") raise ValueError(f"Failed to create UserMandate: {e}") + + def _ensureUserBillingAccount(self, userId: str, mandateId: str) -> None: + """ + Ensure a user has a billing account for the mandate if billing is configured. + Creates account with default credit from settings if billingModel is PREPAY_USER. + + Args: + userId: User ID + mandateId: Mandate ID + """ + try: + from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface + from modules.datamodels.datamodelBilling import BillingModelEnum + + billingInterface = getBillingRootInterface() + settings = billingInterface.getSettings(mandateId) + + if not settings: + return # No billing configured for this mandate + + billingModel = settings.get("billingModel", "UNLIMITED") + if billingModel != BillingModelEnum.PREPAY_USER.value: + return # Only create user accounts for PREPAY_USER model + + defaultCredit = settings.get("defaultUserCredit", 10.0) + billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit) + logger.info(f"Created billing account for user {userId} in mandate {mandateId} with {defaultCredit} CHF") + + except Exception as e: + # Don't fail user mandate creation if billing account creation fails + logger.warning(f"Failed to create billing account for user {userId} (non-critical): {e}") def deleteUserMandate(self, userId: str, mandateId: str) -> bool: """ diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index 141fc118..b3df9bea 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -360,6 +360,98 @@ class BillingObjects: return created + def ensureAllUserAccountsExist(self) -> int: + """ + Efficiently ensure all users across all mandates have billing accounts. + Uses bulk queries to minimize database connections. + + Returns: + Number of accounts created + """ + from modules.interfaces.interfaceDbApp import getRootInterface as getAppRootInterface + + try: + appInterface = getAppRootInterface() + accountsCreated = 0 + + # Step 1: Get all billing settings in one query (only PREPAY_USER mandates need user accounts) + allSettings = self.db.getRecordset(BillingSettings) + prepayUserMandates = {} + for s in allSettings: + if s.get("billingModel") == BillingModelEnum.PREPAY_USER.value: + prepayUserMandates[s.get("mandateId")] = s.get("defaultUserCredit", 10.0) + + if not prepayUserMandates: + logger.debug("No PREPAY_USER mandates found, skipping account check") + return 0 + + # Step 2: Get all existing USER accounts in one query + allAccounts = self.db.getRecordset( + BillingAccount, + recordFilter={"accountType": AccountTypeEnum.USER.value} + ) + # Build set of existing (mandateId, userId) pairs + existingAccountKeys = set() + for acc in allAccounts: + key = (acc.get("mandateId"), acc.get("userId")) + existingAccountKeys.add(key) + + # Step 3: Get all user-mandate combinations in one query + allUserMandates = appInterface.db.getRecordset( + appInterface.db.getModel("UserMandate"), + recordFilter={"enabled": True} + ) + + # Step 4: Find missing accounts and create them + for um in allUserMandates: + mandateId = um.get("mandateId") + userId = um.get("userId") + + if not mandateId or not userId: + continue + + # Only process mandates with PREPAY_USER billing + if mandateId not in prepayUserMandates: + continue + + # Check if account already exists (in memory, no DB call) + key = (mandateId, userId) + if key in existingAccountKeys: + continue + + # Create missing account + defaultCredit = prepayUserMandates[mandateId] + account = BillingAccount( + mandateId=mandateId, + userId=userId, + accountType=AccountTypeEnum.USER, + balance=defaultCredit, + enabled=True + ) + created = self.createAccount(account) + + # Create initial credit transaction + if defaultCredit > 0: + self.createTransaction(BillingTransaction( + accountId=created["id"], + transactionType=TransactionTypeEnum.CREDIT, + amount=defaultCredit, + description="Initial credit for new user", + referenceType=ReferenceTypeEnum.SYSTEM + )) + + existingAccountKeys.add(key) # Track newly created + accountsCreated += 1 + + if accountsCreated > 0: + logger.info(f"Created {accountsCreated} missing billing accounts") + + return accountsCreated + + except Exception as e: + logger.error(f"Error ensuring user accounts exist: {e}") + return 0 + # ========================================================================= # BillingTransaction Operations # ========================================================================= @@ -502,11 +594,16 @@ class BillingObjects: # Get the relevant account if billingModel == BillingModelEnum.PREPAY_USER: account = self.getUserAccount(mandateId, userId) + # Auto-create user account if not exists (with default credit from settings) + if not account: + defaultCredit = settings.get("defaultUserCredit", 10.0) + logger.info(f"Auto-creating billing account for user {userId} in mandate {mandateId} with {defaultCredit} CHF initial credit") + account = self.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit) else: account = self.getMandateAccount(mandateId) if not account: - # No account = no balance = potentially blocked + # No account (only happens for mandate-level accounts) = potentially blocked if settings.get("blockOnZeroBalance", True): return BillingCheckResult( allowed=False, @@ -713,11 +810,18 @@ class BillingObjects: userMandates = appInterface.getUserMandates(userId) for um in userMandates: - mandateId = um.get("mandateId") + # Handle both Pydantic models and dicts + mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) + if not mandateId: + continue + mandate = appInterface.getMandate(mandateId) if not mandate: continue + # Get mandate name (handle both Pydantic and dict) + mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "") + settings = self.getSettings(mandateId) if not settings: continue @@ -740,7 +844,7 @@ class BillingObjects: balances.append(BillingBalanceResponse( mandateId=mandateId, - mandateName=mandate.get("name", ""), + mandateName=mandateName, billingModel=billingModel, balance=balance, warningThreshold=warningThreshold, diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index e3698509..ffeea594 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -123,7 +123,7 @@ async def getBalance( """ try: billingService = getBillingService( - ctx.currentUser, + ctx.user, ctx.mandateId, featureCode="billing" ) @@ -148,7 +148,7 @@ async def getBalanceForMandate( """ try: billingService = getBillingService( - ctx.currentUser, + ctx.user, targetMandateId, featureCode="billing" ) @@ -158,7 +158,7 @@ async def getBalanceForMandate( # Get mandate name from app interface from modules.interfaces.interfaceDbApp import getInterface as getAppInterface - appInterface = getAppInterface(ctx.currentUser, mandateId=targetMandateId) + appInterface = getAppInterface(ctx.user, mandateId=targetMandateId) mandate = appInterface.getMandate(targetMandateId) mandateName = mandate.get("name", "") if mandate else "" @@ -190,7 +190,7 @@ async def getTransactions( """ try: billingService = getBillingService( - ctx.currentUser, + ctx.user, ctx.mandateId, featureCode="billing" ) @@ -240,7 +240,7 @@ async def getStatistics( if period == "day" and not month: raise HTTPException(status_code=400, detail="Month is required for 'day' period") - billingInterface = getBillingInterface(ctx.currentUser, ctx.mandateId) + billingInterface = getBillingInterface(ctx.user, ctx.mandateId) settings = billingInterface.getSettings(ctx.mandateId) if not settings: @@ -256,7 +256,7 @@ async def getStatistics( # Get the relevant account if billingModel == BillingModelEnum.PREPAY_USER: - account = billingInterface.getUserAccount(ctx.mandateId, ctx.currentUser.id) + account = billingInterface.getUserAccount(ctx.mandateId, ctx.user.id) else: account = billingInterface.getMandateAccount(ctx.mandateId) @@ -316,7 +316,7 @@ async def getAllowedProviders( """ try: billingService = getBillingService( - ctx.currentUser, + ctx.user, ctx.mandateId, featureCode="billing" ) @@ -344,7 +344,7 @@ async def getSettingsAdmin( Get billing settings for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) settings = billingInterface.getSettings(targetMandateId) if not settings: @@ -372,7 +372,7 @@ async def createOrUpdateSettings( Create or update billing settings for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) existingSettings = billingInterface.getSettings(targetMandateId) if existingSettings: @@ -421,7 +421,7 @@ async def addCredit( """ try: # Get settings to determine billing model - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) settings = billingInterface.getSettings(targetMandateId) if not settings: @@ -482,7 +482,7 @@ async def getAccounts( Get all billing accounts for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) # Get all accounts for this mandate via interface accounts = billingInterface.getAccountsByMandate(targetMandateId) @@ -507,6 +507,70 @@ async def getAccounts( raise HTTPException(status_code=500, detail=str(e)) +class MandateUserSummary(BaseModel): + """Summary of a user for billing admin purposes.""" + id: str + email: Optional[str] = None + firstName: Optional[str] = None + lastName: Optional[str] = None + displayName: Optional[str] = None + + +@router.get("/admin/users/{targetMandateId}", response_model=List[MandateUserSummary]) +@limiter.limit("30/minute") +async def getUsersForMandate( + request: Request, + targetMandateId: str = Path(..., description="Mandate ID"), + ctx: RequestContext = Depends(getRequestContext), + _admin = Depends(requireSysAdmin) +): + """ + Get all users belonging to a mandate (SysAdmin only). + Used by billing admin to select users for credit assignment. + """ + try: + from modules.interfaces.interfaceDbApp import getInterface as getAppInterface + + appInterface = getAppInterface(ctx.user, mandateId=targetMandateId) + userMandates = appInterface.getUserMandatesByMandate(targetMandateId) + + result = [] + for um in userMandates: + userId = um.get("userId") if isinstance(um, dict) else getattr(um, "userId", None) + if not userId: + continue + + user = appInterface.getUser(userId) + if not user: + continue + + # Handle both Pydantic models and dicts + if isinstance(user, dict): + firstName = user.get("firstName", "") + lastName = user.get("lastName", "") + email = user.get("email", "") + else: + firstName = getattr(user, "firstName", "") or "" + lastName = getattr(user, "lastName", "") or "" + email = getattr(user, "email", "") or "" + + displayName = f"{firstName} {lastName}".strip() or email or userId + + result.append(MandateUserSummary( + id=userId, + email=email, + firstName=firstName, + lastName=lastName, + displayName=displayName + )) + + return result + + except Exception as e: + logger.error(f"Error getting users for mandate {targetMandateId}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/admin/transactions/{targetMandateId}", response_model=List[TransactionResponse]) @limiter.limit("30/minute") async def getTransactionsAdmin( @@ -520,7 +584,7 @@ async def getTransactionsAdmin( Get all transactions for a mandate (SysAdmin only). """ try: - billingInterface = getBillingInterface(ctx.currentUser, targetMandateId) + billingInterface = getBillingInterface(ctx.user, targetMandateId) transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=limit) result = [] diff --git a/modules/services/serviceAi/mainServiceAi.py b/modules/services/serviceAi/mainServiceAi.py index 81d83022..3d2f5cba 100644 --- a/modules/services/serviceAi/mainServiceAi.py +++ b/modules/services/serviceAi/mainServiceAi.py @@ -184,16 +184,17 @@ class AiService: ) logger.debug(f"Automation provider check passed: {effectiveProviders}") - # Check if preferred provider (from UI selection) is allowed - preferredProvider = getattr(self.services, 'preferredProvider', None) - if preferredProvider: - if preferredProvider not in rbacAllowedProviders: - logger.warning(f"Preferred provider {preferredProvider} not allowed for user {user.id}") - raise ProviderNotAllowedException( - provider=preferredProvider, - message=f"Der gewählte Provider '{preferredProvider}' ist für Ihre Rolle nicht freigegeben." - ) - logger.debug(f"Preferred provider {preferredProvider} is allowed") + # Check if preferred providers (from UI multiselect) are allowed + preferredProviders = getattr(self.services, 'preferredProviders', None) + if preferredProviders: + for provider in preferredProviders: + if provider not in rbacAllowedProviders: + logger.warning(f"Preferred provider {provider} not allowed for user {user.id}") + raise ProviderNotAllowedException( + provider=provider, + message=f"Der gewählte Provider '{provider}' ist für Ihre Rolle nicht freigegeben." + ) + logger.debug(f"All preferred providers are allowed: {preferredProviders}") logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed") diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py index 99a89df1..6a0a00e4 100644 --- a/modules/workflows/automation/mainWorkflow.py +++ b/modules/workflows/automation/mainWorkflow.py @@ -42,10 +42,14 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode try: services = getServices(currentUser, mandateId=mandateId) - # Store preferred provider in services context for billing/model selection - if hasattr(userInput, 'preferredProvider') and userInput.preferredProvider: - services.preferredProvider = userInput.preferredProvider - logger.debug(f"Using preferred provider: {userInput.preferredProvider}") + # Store preferred providers in services context for billing/model selection + # Support both preferredProviders (list) and legacy preferredProvider (string) + if hasattr(userInput, 'preferredProviders') and userInput.preferredProviders: + services.preferredProviders = userInput.preferredProviders + logger.debug(f"Using preferred providers: {userInput.preferredProviders}") + elif hasattr(userInput, 'preferredProvider') and userInput.preferredProvider: + services.preferredProviders = [userInput.preferredProvider] + logger.debug(f"Using preferred provider (legacy): {userInput.preferredProvider}") # Store feature instance ID in services context if featureInstanceId: @@ -59,10 +63,14 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode logger.error(f"Error starting chat: {str(e)}") raise -async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None) -> ChatWorkflow: +async def chatStop(currentUser: User, workflowId: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow: """Stops a running chat.""" try: services = getServices(currentUser, mandateId=mandateId) + # Store feature instance ID in services context for proper RBAC filtering + if featureInstanceId: + services.featureInstanceId = featureInstanceId + services.featureCode = 'chatplayground' workflowManager = WorkflowManager(services) return await workflowManager.workflowStop(workflowId) except Exception as e: diff --git a/modules/workflows/workflowManager.py b/modules/workflows/workflowManager.py index d0c35bf1..b15c66b7 100644 --- a/modules/workflows/workflowManager.py +++ b/modules/workflows/workflowManager.py @@ -97,6 +97,7 @@ class WorkflowManager: "totalTasks": 0, "totalActions": 0, "mandateId": self.services.mandateId, + "featureInstanceId": getattr(self.services, 'featureInstanceId', None), # Feature instance ID for isolation "messageIds": [], "workflowMode": workflowMode, "maxSteps": 10 , # Set maxSteps