diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index 8ba3ced1..02f80762 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -399,6 +399,7 @@ class UserInputRequest(BaseModel): listFileId: List[str] = Field(default_factory=list, description="List of file IDs") userLanguage: str = Field(default="en", description="User's preferred language") workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue") + preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai')") registerModelLabels( @@ -408,6 +409,7 @@ registerModelLabels( "prompt": {"en": "Prompt", "fr": "Invite"}, "listFileId": {"en": "File IDs", "fr": "IDs des fichiers"}, "userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"}, + "preferredProvider": {"en": "Preferred Provider", "fr": "Fournisseur préféré"}, }, ) diff --git a/modules/services/serviceAi/mainServiceAi.py b/modules/services/serviceAi/mainServiceAi.py index a728bafc..81d83022 100644 --- a/modules/services/serviceAi/mainServiceAi.py +++ b/modules/services/serviceAi/mainServiceAi.py @@ -18,6 +18,11 @@ from modules.shared.jsonUtils import ( ) from .subJsonResponseHandling import JsonResponseHandler from modules.datamodels.datamodelAi import JsonAccumulationState +from modules.services.serviceBilling.mainServiceBilling import ( + getService as getBillingService, + InsufficientBalanceException, + ProviderNotAllowedException +) logger = logging.getLogger(__name__) @@ -85,12 +90,120 @@ class AiService: Replaces direct calls to self.aiObjects.call() to route content parts processing through serviceExtraction layer. + + Includes billing checks: + - Balance check before AI call + - Provider permission check (via RBAC) """ + # Billing check before AI call + await self._checkBillingBeforeAiCall() + if hasattr(request, 'contentParts') and request.contentParts: return await self.extractionService.processContentPartsWithAi( request, self.aiObjects, progressCallback ) return await self.aiObjects.callWithTextContext(request) + + async def _checkBillingBeforeAiCall(self) -> None: + """ + Check billing status before making an AI call. + + Verifies: + 1. User has sufficient balance (for prepay models) + 2. Provider is allowed for the user (via RBAC) + + Raises: + InsufficientBalanceException: If balance is insufficient + ProviderNotAllowedException: If provider is not allowed + """ + try: + # Get context from services + if not self.services: + logger.debug("No service center - skipping billing check") + return + + user = getattr(self.services, 'user', None) + mandateId = getattr(self.services, 'mandateId', None) + + if not user or not mandateId: + logger.debug("No user or mandate context - skipping billing check") + return + + # Get feature context + featureInstanceId = getattr(self.services, 'featureInstanceId', None) + featureCode = getattr(self.services, 'featureCode', None) + + # Get billing service + billingService = getBillingService( + user, + mandateId, + featureInstanceId, + featureCode + ) + + # Check balance (estimate typical AI call cost) + # We use a small estimate here; actual cost is recorded after the call + estimatedCost = 0.01 # ~1 cent CHF minimum + balanceCheck = billingService.checkBalance(estimatedCost) + + if not balanceCheck.allowed: + logger.warning( + f"Billing check failed for user {user.id}: " + f"Balance {balanceCheck.currentBalance:.2f} CHF, " + f"Reason: {balanceCheck.reason}" + ) + raise InsufficientBalanceException( + currentBalance=balanceCheck.currentBalance or 0.0, + requiredAmount=estimatedCost, + message=f"Ungenügendes Guthaben. Aktuell: CHF {balanceCheck.currentBalance:.2f}" + ) + + logger.debug(f"Billing check passed: Balance {balanceCheck.currentBalance:.2f} CHF") + + # Check if at least one provider is allowed (RBAC check) + rbacAllowedProviders = billingService.getallowedProviders() + if not rbacAllowedProviders: + logger.warning(f"No AI providers allowed for user {user.id} in mandate {mandateId}") + raise ProviderNotAllowedException( + provider="any", + message="Keine AI-Provider für Ihre Rolle freigegeben. Kontaktieren Sie Ihren Administrator." + ) + + # Check automation-level allowedProviders restriction + automationAllowedProviders = getattr(self.services, 'allowedProviders', None) + if automationAllowedProviders: + # Filter by both RBAC and automation-level restrictions + effectiveProviders = [p for p in automationAllowedProviders if p in rbacAllowedProviders] + if not effectiveProviders: + logger.warning(f"No providers available after automation restriction. " + f"Automation allows: {automationAllowedProviders}, " + f"RBAC allows: {rbacAllowedProviders}") + raise ProviderNotAllowedException( + provider="any", + message="Die konfigurierten AI-Provider dieser Automation sind für Ihre Rolle nicht freigegeben." + ) + logger.debug(f"Automation provider check passed: {effectiveProviders}") + + # Check if preferred provider (from UI selection) is allowed + preferredProvider = getattr(self.services, 'preferredProvider', None) + if preferredProvider: + if preferredProvider not in rbacAllowedProviders: + logger.warning(f"Preferred provider {preferredProvider} not allowed for user {user.id}") + raise ProviderNotAllowedException( + provider=preferredProvider, + message=f"Der gewählte Provider '{preferredProvider}' ist für Ihre Rolle nicht freigegeben." + ) + logger.debug(f"Preferred provider {preferredProvider} is allowed") + + logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed") + + except InsufficientBalanceException: + raise # Re-raise billing exceptions + except ProviderNotAllowedException: + raise # Re-raise provider exceptions + except Exception as e: + # Log but don't block on billing check errors + logger.warning(f"Billing check failed with error (non-blocking): {e}") async def ensureAiObjectsInitialized(self): """Ensure aiObjects is initialized and submodules are ready.""" diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py index e7c6839e..99a89df1 100644 --- a/modules/workflows/automation/mainWorkflow.py +++ b/modules/workflows/automation/mainWorkflow.py @@ -24,7 +24,7 @@ from .subAutomationUtils import parseScheduleToCron, planToPrompt, replacePlaceh logger = logging.getLogger(__name__) -async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None) -> ChatWorkflow: +async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow: """ Starts a new chat or continues an existing one, then launches processing asynchronously. @@ -34,12 +34,24 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode workflowId: Optional workflow ID to continue existing workflow workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution mandateId: Mandate ID from request context (required for proper data isolation) + featureInstanceId: Feature instance ID for context Example usage for Dynamic mode: workflow = await chatStart(currentUser, userInput, workflowMode=WorkflowModeEnum.WORKFLOW_DYNAMIC, mandateId=mandateId) """ try: services = getServices(currentUser, mandateId=mandateId) + + # Store preferred provider in services context for billing/model selection + if hasattr(userInput, 'preferredProvider') and userInput.preferredProvider: + services.preferredProvider = userInput.preferredProvider + logger.debug(f"Using preferred provider: {userInput.preferredProvider}") + + # Store feature instance ID in services context + if featureInstanceId: + services.featureInstanceId = featureInstanceId + services.featureCode = 'chatplayground' + workflowManager = WorkflowManager(services) workflow = await workflowManager.workflowStart(userInput, workflowMode, workflowId) return workflow @@ -84,6 +96,14 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow: executionLog["messages"].append(f"Started execution at {executionStartTime}") + # Store allowed providers from automation in services context + if hasattr(automation, 'allowedProviders') and automation.allowedProviders: + services.allowedProviders = automation.allowedProviders + logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}") + + # Store feature context for billing + services.featureCode = 'automation' + # 2. Replace placeholders in template to generate plan template = automation.template or "" placeholders = automation.placeholders or {}