diff --git a/modules/datamodels/datamodelAi.py b/modules/datamodels/datamodelAi.py index c9d81bfa..4233b7d7 100644 --- a/modules/datamodels/datamodelAi.py +++ b/modules/datamodels/datamodelAi.py @@ -144,6 +144,9 @@ class AiCallOptions(BaseModel): temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)") maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch") + # Provider filtering (from UI multiselect or automation config) + allowedProviders: Optional[List[str]] = Field(default=None, description="List of allowed AI providers to use (empty = all RBAC-permitted)") + class AiCallRequest(BaseModel): """Centralized AI call request payload for interface use.""" diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index e2d631e8..fbad3d57 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -403,8 +403,7 @@ class UserInputRequest(BaseModel): listFileId: List[str] = Field(default_factory=list, description="List of file IDs") userLanguage: str = Field(default="en", description="User's preferred language") workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue") - preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai') - deprecated, use preferredProviders") - preferredProviders: Optional[List[str]] = Field(None, description="List of preferred AI providers (multiselect)") + allowedProviders: Optional[List[str]] = Field(None, description="List of allowed AI providers (multiselect)") registerModelLabels( diff --git a/modules/interfaces/interfaceAiObjects.py b/modules/interfaces/interfaceAiObjects.py index 2e6e36f5..b2d91ed0 100644 --- a/modules/interfaces/interfaceAiObjects.py +++ b/modules/interfaces/interfaceAiObjects.py @@ -89,6 +89,17 @@ class AiObjects: # Get failover models for this operation type availableModels = modelRegistry.getAvailableModels() + + # Filter by allowedProviders if specified (from workflow config) + allowedProviders = getattr(options, 'allowedProviders', None) if options else None + if allowedProviders: + filteredModels = [m for m in availableModels if m.connectorType in allowedProviders] + if filteredModels: + logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})") + availableModels = filteredModels + else: + logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models") + failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels) if not failoverModelList: diff --git a/modules/interfaces/interfaceDbChat.py b/modules/interfaces/interfaceDbChat.py index 0a3971a8..9ee20fc0 100644 --- a/modules/interfaces/interfaceDbChat.py +++ b/modules/interfaces/interfaceDbChat.py @@ -1532,8 +1532,19 @@ class ChatObjects: return [] # Return all stats records sorted by creation time - stats.sort(key=lambda x: x.get("created_at", "")) - return [ChatStat(**stat) for stat in stats] + # DB uses _createdAt (camelCase system field) + stats.sort(key=lambda x: x.get("_createdAt", 0)) + + # Convert to ChatStat objects, preserving _createdAt via extra="allow" + result = [] + for stat in stats: + chat_stat = ChatStat(**stat) + # Explicitly preserve _createdAt from raw DB record + if "_createdAt" in stat: + setattr(chat_stat, '_createdAt', stat["_createdAt"]) + result.append(chat_stat) + + return result def createStat(self, statData: Dict[str, Any]) -> ChatStat: @@ -1549,9 +1560,16 @@ class ChatObjects: # Validate the stat data against ChatStat model stat = ChatStat(**statData) + logger.debug(f"Creating stat for workflow {statData.get('workflowId')}: " + f"process={statData.get('process')}, " + f"priceCHF={statData.get('priceCHF', 0):.4f}, " + f"processingTime={statData.get('processingTime', 0):.2f}s") + # Create the stat record in the database created = self.db.recordCreate(ChatStat, stat) + logger.info(f"Created stat {created.get('id')} for workflow {statData.get('workflowId')}") + # Return the created ChatStat return ChatStat(**created) except Exception as e: @@ -1643,10 +1661,14 @@ class ChatObjects: if afterTimestamp is not None and stat_timestamp <= afterTimestamp: continue + # Convert to dict and include _createdAt for frontend + stat_dict = stat.model_dump() if hasattr(stat, 'model_dump') else stat.dict() + stat_dict['_createdAt'] = stat_timestamp + items.append({ "type": "stat", "createdAt": stat_timestamp, - "item": stat + "item": stat_dict }) # Sort all items by createdAt timestamp for chronological order diff --git a/modules/services/serviceAi/mainServiceAi.py b/modules/services/serviceAi/mainServiceAi.py index 3d2f5cba..5fdf32a5 100644 --- a/modules/services/serviceAi/mainServiceAi.py +++ b/modules/services/serviceAi/mainServiceAi.py @@ -94,15 +94,30 @@ class AiService: Includes billing checks: - Balance check before AI call - Provider permission check (via RBAC) + + Also stores workflow stats after each successful AI call. """ - # Billing check before AI call + # Billing check before AI call (validates RBAC permissions) await self._checkBillingBeforeAiCall() + # Calculate effective allowedProviders: RBAC ∩ Workflow + # RBAC is master - only RBAC-permitted providers can ever be used + effectiveProviders = self._calculateEffectiveProviders() + if effectiveProviders and request.options: + request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders}) + logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}") + if hasattr(request, 'contentParts') and request.contentParts: - return await self.extractionService.processContentPartsWithAi( + response = await self.extractionService.processContentPartsWithAi( request, self.aiObjects, progressCallback ) - return await self.aiObjects.callWithTextContext(request) + else: + response = await self.aiObjects.callWithTextContext(request) + + # Store workflow stats after each AI call + self._storeAiCallStats(response, request) + + return response async def _checkBillingBeforeAiCall(self) -> None: """ @@ -206,6 +221,87 @@ class AiService: # Log but don't block on billing check errors logger.warning(f"Billing check failed with error (non-blocking): {e}") + def _calculateEffectiveProviders(self) -> Optional[List[str]]: + """ + Calculate effective allowed providers: RBAC ∩ Workflow. + + RBAC is master - only RBAC-permitted providers can ever be used. + If workflow specifies allowedProviders, intersect with RBAC. + If no workflow providers, use all RBAC-permitted providers. + + Returns: + List of effective allowed providers, or None if no filtering needed + """ + try: + user = getattr(self.services, 'user', None) + mandateId = getattr(self.services, 'mandateId', None) + + if not user or not mandateId: + return None + + # Get RBAC-permitted providers (master list) + # Note: getBillingService is imported at module level from mainServiceBilling + featureInstanceId = getattr(self.services, 'featureInstanceId', None) + featureCode = getattr(self.services, 'featureCode', None) + billingService = getBillingService(user, mandateId, featureInstanceId, featureCode) + rbacProviders = billingService.getallowedProviders() + + if not rbacProviders: + logger.warning("No RBAC-permitted providers found") + return None + + # Get workflow-specified providers (optional filter) + workflowProviders = getattr(self.services, 'allowedProviders', None) + + if workflowProviders: + # Intersect: only providers that are both RBAC-permitted AND workflow-allowed + effectiveProviders = [p for p in workflowProviders if p in rbacProviders] + logger.debug(f"Provider filter: RBAC={rbacProviders}, Workflow={workflowProviders}, Effective={effectiveProviders}") + else: + # No workflow filter - use all RBAC-permitted providers + effectiveProviders = rbacProviders + logger.debug(f"Provider filter: RBAC={rbacProviders} (no workflow filter)") + + return effectiveProviders if effectiveProviders else None + + except Exception as e: + logger.warning(f"Error calculating effective providers: {e}") + return None + + def _storeAiCallStats(self, response, request: AiCallRequest) -> None: + """Store workflow stats after an AI call. + + This method stores the AI call statistics (cost, processing time, bytes) + to the workflow stats collection for tracking and billing purposes. + + Args: + response: AiCallResponse with cost/timing data + request: Original AiCallRequest for context + """ + try: + # Skip if no workflow context + workflow = getattr(self.services, 'workflow', None) + if not workflow or not hasattr(workflow, 'id') or not workflow.id: + logger.debug("No workflow context - skipping stats storage") + return + + # Skip if response is an error + if not response or getattr(response, 'errorCount', 0) > 0: + logger.debug("Error response - skipping stats storage") + return + + # Determine process name from operation type + opType = getattr(request.options, 'operationType', 'unknown') if request.options else 'unknown' + process = f"ai.call.{opType}" + + # Store the stat + self.services.chat.storeWorkflowStat(workflow, response, process) + logger.debug(f"Stored AI call stat: {process}, cost={getattr(response, 'priceCHF', 0):.4f} CHF") + + except Exception as e: + # Log but don't fail - stats storage is not critical + logger.debug(f"Could not store AI call stat: {str(e)}") + async def ensureAiObjectsInitialized(self): """Ensure aiObjects is initialized and submodules are ready.""" if self.aiObjects is None: @@ -428,7 +524,7 @@ Respond with ONLY a JSON object in this exact format: # Debug: persist prompt/response for analysis with context-specific naming debugPrefix = debugType if debugType else "plan" self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt") - response = await self.aiObjects.callWithTextContext(request) + response = await self.callAi(request) # Use callAi to ensure stats are stored result = response.content or "" self.services.utils.writeDebugFile(result, f"{debugPrefix}_response") return result @@ -485,16 +581,7 @@ Respond with ONLY a JSON object in this exact format: operationType=opType.value ) - # Try to store workflow stats, but don't fail if workflow is None (e.g., in chatbot context) - try: - self.services.chat.storeWorkflowStat( - self.services.workflow, - response, - f"ai.{opType.name.lower()}" - ) - except Exception as e: - # Log but don't fail - workflow might be None in some contexts (e.g., chatbot) - logger.debug(f"Could not store workflow stat (workflow may be None): {str(e)}") + # Note: Stats are now stored centrally in callAi() - no need to duplicate here self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed") self.services.chat.progressLogFinish(aiOperationId, True) diff --git a/modules/services/serviceAi/subAiCallLooping.py b/modules/services/serviceAi/subAiCallLooping.py index 5f3fb79f..2e4edc3e 100644 --- a/modules/services/serviceAi/subAiCallLooping.py +++ b/modules/services/serviceAi/subAiCallLooping.py @@ -269,17 +269,7 @@ class AiCallLooper: # Document generation - save all iteration responses self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}") - # Emit stats for this iteration (only if workflow exists and has id) - if self.services.workflow and hasattr(self.services.workflow, 'id') and self.services.workflow.id: - try: - self.services.chat.storeWorkflowStat( - self.services.workflow, - response, - f"ai.call.{debugPrefix}.iteration_{iteration}" - ) - except Exception as statError: - # Don't break the main loop if stat storage fails - logger.warning(f"Failed to store workflow stat: {str(statError)}") + # Note: Stats are now stored centrally in callAi() - no need to duplicate here # Check for error response using generic error detection (errorCount > 0 or modelName == "error") if hasattr(response, 'errorCount') and response.errorCount > 0: diff --git a/modules/services/serviceGeneration/paths/imagePath.py b/modules/services/serviceGeneration/paths/imagePath.py index 1247494f..c61bc997 100644 --- a/modules/services/serviceGeneration/paths/imagePath.py +++ b/modules/services/serviceGeneration/paths/imagePath.py @@ -101,11 +101,7 @@ class ImageGenerationPath: operationType=OperationTypeEnum.IMAGE_GENERATE.value ) - self.services.chat.storeWorkflowStat( - self.services.workflow, - response, - "ai.generate.image" - ) + # Note: Stats are now stored centrally in callAi() - no need to duplicate here self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated") self.services.chat.progressLogFinish(imageOperationId, True) diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py index 6a0a00e4..e63f7932 100644 --- a/modules/workflows/automation/mainWorkflow.py +++ b/modules/workflows/automation/mainWorkflow.py @@ -42,14 +42,10 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode try: services = getServices(currentUser, mandateId=mandateId) - # Store preferred providers in services context for billing/model selection - # Support both preferredProviders (list) and legacy preferredProvider (string) - if hasattr(userInput, 'preferredProviders') and userInput.preferredProviders: - services.preferredProviders = userInput.preferredProviders - logger.debug(f"Using preferred providers: {userInput.preferredProviders}") - elif hasattr(userInput, 'preferredProvider') and userInput.preferredProvider: - services.preferredProviders = [userInput.preferredProvider] - logger.debug(f"Using preferred provider (legacy): {userInput.preferredProvider}") + # Store allowedProviders in services context for model selection + if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders: + services.allowedProviders = userInput.allowedProviders + logger.info(f"AI provider filter active: {userInput.allowedProviders}") # Store feature instance ID in services context if featureInstanceId: diff --git a/modules/workflows/processing/workflowProcessor.py b/modules/workflows/processing/workflowProcessor.py index 38763f51..a78a2270 100644 --- a/modules/workflows/processing/workflowProcessor.py +++ b/modules/workflows/processing/workflowProcessor.py @@ -14,7 +14,8 @@ from modules.workflows.processing.modes.modeDynamic import DynamicMode from modules.workflows.processing.modes.modeAutomation import AutomationMode from modules.workflows.processing.shared.stateTools import checkWorkflowStopped from modules.datamodels.datamodelAi import OperationTypeEnum, PriorityEnum, ProcessingModeEnum, AiCallOptions, AiCallRequest -from modules.shared.jsonUtils import extractJsonString, repairBrokenJson +from modules.shared.jsonUtils import extractJsonString, repairBrokenJson, parseJsonWithModel +from modules.datamodels.datamodelWorkflow import UnderstandingResult if TYPE_CHECKING: from modules.datamodels.datamodelWorkflow import TaskResult @@ -477,8 +478,7 @@ class WorkflowProcessor: maxProcessingTime=15 # Fast path should complete in 15s ) - # Call AI directly (no document generation - just plain text response) - # Use callWithTextContext() for text-only calls + # Call AI via callAi() to ensure stats are stored aiRequest = AiCallRequest( prompt=fastPathPrompt, context="", @@ -486,7 +486,7 @@ class WorkflowProcessor: contentParts=None # Fast path doesn't process documents ) - aiCallResponse = await self.services.ai.aiObjects.callWithTextContext(aiRequest) + aiCallResponse = await self.services.ai.callAi(aiRequest) # Extract response content (AiCallResponse.content is a string) responseText = aiCallResponse.content if aiCallResponse.content else ""