diff --git a/modules/features/chatbot/interfaceFeatureChatbot.py b/modules/features/chatbot/interfaceFeatureChatbot.py index 160addca..c3c99c1d 100644 --- a/modules/features/chatbot/interfaceFeatureChatbot.py +++ b/modules/features/chatbot/interfaceFeatureChatbot.py @@ -681,7 +681,10 @@ class ChatObjects: logs=logs, messages=messages, stats=stats, - mandateId=workflow.get("mandateId", self.mandateId) + mandateId=workflow.get("mandateId", self.mandateId), + featureInstanceId=workflow.get("featureInstanceId") or self.featureInstanceId or "", + workflowMode=workflow.get("workflowMode", "chatbot"), + maxSteps=workflow.get("maxSteps") if workflow.get("maxSteps") is not None else 1 ) except Exception as e: logger.error(f"Error validating workflow data: {str(e)}") @@ -706,6 +709,10 @@ class ChatObjects: if "featureInstanceId" not in workflowData or not workflowData["featureInstanceId"]: workflowData["featureInstanceId"] = self.featureInstanceId + # Ensure featureInstanceId is set (required field) + if not workflowData.get("featureInstanceId"): + workflowData["featureInstanceId"] = self.featureInstanceId or "" + # Use generic field separation based on ChatWorkflow model simpleFields, objectFields = self._separateObjectFields(ChatWorkflow, workflowData) @@ -729,6 +736,7 @@ class ChatObjects: messages=[], stats=[], mandateId=created.get("mandateId", self.mandateId), + featureInstanceId=created.get("featureInstanceId") or self.featureInstanceId or "", workflowMode=created["workflowMode"], maxSteps=created.get("maxSteps", 1) ) @@ -775,7 +783,10 @@ class ChatObjects: logs=logs, messages=messages, stats=stats, - mandateId=updated.get("mandateId", workflow.mandateId) + mandateId=updated.get("mandateId", workflow.mandateId), + featureInstanceId=updated.get("featureInstanceId") or workflow.featureInstanceId or self.featureInstanceId or "", + workflowMode=updated.get("workflowMode") if updated.get("workflowMode") is not None else (workflow.workflowMode if hasattr(workflow, 'workflowMode') and workflow.workflowMode else "chatbot"), + maxSteps=updated.get("maxSteps") if updated.get("maxSteps") is not None else (workflow.maxSteps if hasattr(workflow, 'maxSteps') and workflow.maxSteps is not None else 1) ) def deleteWorkflow(self, workflowId: str) -> bool: @@ -881,7 +892,9 @@ class ChatObjects: "taskNumber": msg.get("taskNumber"), "actionNumber": msg.get("actionNumber"), "taskProgress": msg.get("taskProgress"), - "actionProgress": msg.get("actionProgress") + "actionProgress": msg.get("actionProgress"), + "mandateId": msg.get("mandateId") or self.mandateId or "", + "featureInstanceId": msg.get("featureInstanceId") or self.featureInstanceId or "" }) # Apply default sorting by publishedAt if no sort specified @@ -924,7 +937,9 @@ class ChatObjects: taskNumber=msg.get("taskNumber"), actionNumber=msg.get("actionNumber"), taskProgress=msg.get("taskProgress"), - actionProgress=msg.get("actionProgress") + actionProgress=msg.get("actionProgress"), + mandateId=msg.get("mandateId") or self.mandateId or "", + featureInstanceId=msg.get("featureInstanceId") or self.featureInstanceId or "" ) chat_messages.append(chat_message) @@ -1095,7 +1110,9 @@ class ChatObjects: success=createdMessage.get("success"), actionId=createdMessage.get("actionId"), actionMethod=createdMessage.get("actionMethod"), - actionName=createdMessage.get("actionName") + actionName=createdMessage.get("actionName"), + mandateId=createdMessage.get("mandateId") or self.mandateId or "", + featureInstanceId=createdMessage.get("featureInstanceId") or self.featureInstanceId or "" ) # Emit message event for streaming (if event manager is available) @@ -1318,7 +1335,8 @@ class ChatObjects: """Returns documents for a message from normalized table.""" try: documents = getRecordsetWithRBAC(self.db, ChatDocument, self.currentUser, recordFilter={"messageId": messageId}) - return [ChatDocument(**doc) for doc in documents] + # Ensure mandateId and featureInstanceId are set for each document + return [ChatDocument(**{**doc, "mandateId": doc.get("mandateId") or self.mandateId or "", "featureInstanceId": doc.get("featureInstanceId") or self.featureInstanceId or ""}) for doc in documents] except Exception as e: logger.error(f"Error getting message documents: {str(e)}") return [] @@ -1338,7 +1356,13 @@ class ChatObjects: created = self.db.recordCreate(ChatDocument, document.model_dump()) if created: - created_doc = ChatDocument(**created) + # Ensure mandateId and featureInstanceId are set + doc_dict = dict(created) + if "mandateId" not in doc_dict or not doc_dict["mandateId"]: + doc_dict["mandateId"] = self.mandateId or "" + if "featureInstanceId" not in doc_dict or not doc_dict["featureInstanceId"]: + doc_dict["featureInstanceId"] = self.featureInstanceId or "" + created_doc = ChatDocument(**doc_dict) logger.debug(f"Successfully created document in database: {created_doc.fileName} (id: {created_doc.id})") return created_doc else: @@ -1392,7 +1416,8 @@ class ChatObjects: "agentName": log.get("agentName"), "status": log.get("status"), "progress": log.get("progress"), - "mandateId": log.get("mandateId"), + "mandateId": log.get("mandateId") or self.mandateId or "", + "featureInstanceId": log.get("featureInstanceId") or self.featureInstanceId or "", "userId": log.get("userId") }) @@ -1410,7 +1435,8 @@ class ChatObjects: # If no pagination requested, return all items if pagination is None: - return [ChatLog(**log) for log in logDicts] + # Ensure mandateId and featureInstanceId are set for each log + return [ChatLog(**{**log, "mandateId": log.get("mandateId") or self.mandateId or "", "featureInstanceId": log.get("featureInstanceId") or self.featureInstanceId or ""}) for log in logDicts] # Count total items after filters totalItems = len(logDicts) @@ -1422,7 +1448,8 @@ class ChatObjects: pagedLogDicts = logDicts[startIdx:endIdx] # Convert to model objects - items = [ChatLog(**log) for log in pagedLogDicts] + # Ensure mandateId and featureInstanceId are set for each log + items = [ChatLog(**{**log, "mandateId": log.get("mandateId") or self.mandateId or "", "featureInstanceId": log.get("featureInstanceId") or self.featureInstanceId or ""}) for log in pagedLogDicts] return PaginatedResult( items=items, @@ -1500,7 +1527,7 @@ class ChatObjects: { "type": "log", "createdAt": log_timestamp, - "item": ChatLog(**createdLog).model_dump() + "item": ChatLog(**{**createdLog, "mandateId": createdLog.get("mandateId") or self.mandateId or "", "featureInstanceId": createdLog.get("featureInstanceId") or self.featureInstanceId or ""}).model_dump() } )) except Exception as e: @@ -1508,7 +1535,13 @@ class ChatObjects: logger.debug(f"Could not emit log event: {e}") # Return validated ChatLog instance - return ChatLog(**createdLog) + # Ensure mandateId and featureInstanceId are set + log_dict = dict(createdLog) + if "mandateId" not in log_dict or not log_dict["mandateId"]: + log_dict["mandateId"] = self.mandateId or "" + if "featureInstanceId" not in log_dict or not log_dict["featureInstanceId"]: + log_dict["featureInstanceId"] = self.featureInstanceId or "" + return ChatLog(**log_dict) # Stats methods @@ -1533,7 +1566,8 @@ class ChatObjects: # Return all stats records sorted by creation time stats.sort(key=lambda x: x.get("created_at", "")) - return [ChatStat(**stat) for stat in stats] + # Ensure mandateId and featureInstanceId are set for each stat + return [ChatStat(**{**stat, "mandateId": stat.get("mandateId") or self.mandateId or "", "featureInstanceId": stat.get("featureInstanceId") or self.featureInstanceId or ""}) for stat in stats] def createStat(self, statData: Dict[str, Any]) -> ChatStat: @@ -1556,7 +1590,13 @@ class ChatObjects: created = self.db.recordCreate(ChatStat, stat) # Return the created ChatStat - return ChatStat(**created) + # Ensure mandateId and featureInstanceId are set + stat_dict = dict(created) + if "mandateId" not in stat_dict or not stat_dict["mandateId"]: + stat_dict["mandateId"] = self.mandateId or "" + if "featureInstanceId" not in stat_dict or not stat_dict["featureInstanceId"]: + stat_dict["featureInstanceId"] = self.featureInstanceId or "" + return ChatStat(**stat_dict) except Exception as e: logger.error(f"Error creating workflow stat: {str(e)}") raise @@ -1612,7 +1652,9 @@ class ChatObjects: taskNumber=msg.get("taskNumber"), actionNumber=msg.get("actionNumber"), taskProgress=msg.get("taskProgress"), - actionProgress=msg.get("actionProgress") + actionProgress=msg.get("actionProgress"), + mandateId=msg.get("mandateId") or self.mandateId or "", + featureInstanceId=msg.get("featureInstanceId") or self.featureInstanceId or "" ) # Use publishedAt as the timestamp for chronological ordering @@ -1630,7 +1672,9 @@ class ChatObjects: if afterTimestamp is not None and logTimestamp <= afterTimestamp: continue - chatLog = ChatLog(**log) + # Ensure mandateId and featureInstanceId are set + log_dict = {**log, "mandateId": log.get("mandateId") or self.mandateId or "", "featureInstanceId": log.get("featureInstanceId") or self.featureInstanceId or ""} + chatLog = ChatLog(**log_dict) items.append({ "type": "log", "createdAt": logTimestamp, diff --git a/modules/features/chatbot/mainChatbot.py b/modules/features/chatbot/mainChatbot.py index 451a28f8..6a172adf 100644 --- a/modules/features/chatbot/mainChatbot.py +++ b/modules/features/chatbot/mainChatbot.py @@ -9,6 +9,7 @@ This module also handles feature initialization and RBAC catalog registration. """ import logging +from typing import Dict, List, Any # Feature metadata for RBAC catalog FEATURE_CODE = "chatbot" @@ -34,12 +35,47 @@ RESOURCE_OBJECTS = [ { "objectKey": "resource.feature.chatbot.start", "label": {"en": "Start Chatbot", "de": "Chatbot starten", "fr": "Démarrer chatbot"}, - "meta": {"endpoint": "/api/chatbot/start/stream", "method": "POST"} + "meta": {"endpoint": "/api/chatbot/{instanceId}/start/stream", "method": "POST"} }, { "objectKey": "resource.feature.chatbot.stop", "label": {"en": "Stop Chatbot", "de": "Chatbot stoppen", "fr": "Arrêter chatbot"}, - "meta": {"endpoint": "/api/chatbot/stop/{workflowId}", "method": "POST"} + "meta": {"endpoint": "/api/chatbot/{instanceId}/stop/{workflowId}", "method": "POST"} + }, +] + +# DATA Objects for RBAC catalog (tables/entities) +# Used for AccessRules on data-level permissions +DATA_OBJECTS = [ + { + "objectKey": "data.feature.chatbot.ChatWorkflow", + "label": {"en": "Chat Workflow", "de": "Chat-Workflow", "fr": "Workflow de chat"}, + "meta": {"table": "ChatWorkflow", "fields": ["id", "name", "status", "mandateId", "featureInstanceId"]} + }, + { + "objectKey": "data.feature.chatbot.ChatMessage", + "label": {"en": "Chat Message", "de": "Chat-Nachricht", "fr": "Message de chat"}, + "meta": {"table": "ChatMessage", "fields": ["id", "workflowId", "message", "role", "publishedAt"]} + }, + { + "objectKey": "data.feature.chatbot.ChatLog", + "label": {"en": "Chat Log", "de": "Chat-Log", "fr": "Journal de chat"}, + "meta": {"table": "ChatLog", "fields": ["id", "workflowId", "message", "type", "timestamp"]} + }, + { + "objectKey": "data.feature.chatbot.ChatDocument", + "label": {"en": "Chat Document", "de": "Chat-Dokument", "fr": "Document de chat"}, + "meta": {"table": "ChatDocument", "fields": ["id", "messageId", "fileId", "fileName", "fileSize", "mimeType"]} + }, + { + "objectKey": "data.feature.chatbot.ChatStat", + "label": {"en": "Chat Statistics", "de": "Chat-Statistiken", "fr": "Statistiques de chat"}, + "meta": {"table": "ChatStat", "fields": ["id", "workflowId", "processingTime", "bytesSent", "bytesReceived", "errorCount"]} + }, + { + "objectKey": "data.feature.chatbot.*", + "label": {"en": "All Chatbot Data", "de": "Alle Chatbot-Daten", "fr": "Toutes les données chatbot"}, + "meta": {"wildcard": True, "description": "Wildcard for all chatbot data tables"} }, ] @@ -104,9 +140,23 @@ def getTemplateRoles(): return TEMPLATE_ROLES +def getDataObjects(): + """Return DATA objects for RBAC catalog registration.""" + return DATA_OBJECTS + + def registerFeature(catalogService) -> bool: - """Register this feature's RBAC objects in the catalog.""" + """ + Register this feature's RBAC objects in the catalog. + + Args: + catalogService: The RBAC catalog service instance + + Returns: + True if registration was successful + """ try: + # Register UI objects for uiObj in UI_OBJECTS: catalogService.registerUiObject( featureCode=FEATURE_CODE, @@ -115,6 +165,7 @@ def registerFeature(catalogService) -> bool: meta=uiObj.get("meta") ) + # Register Resource objects for resObj in RESOURCE_OBJECTS: catalogService.registerResourceObject( featureCode=FEATURE_CODE, @@ -123,10 +174,210 @@ def registerFeature(catalogService) -> bool: meta=resObj.get("meta") ) + # Register DATA objects (tables/entities) + for dataObj in DATA_OBJECTS: + catalogService.registerDataObject( + featureCode=FEATURE_CODE, + objectKey=dataObj["objectKey"], + label=dataObj["label"], + meta=dataObj.get("meta") + ) + + # Sync template roles to database (with AccessRules) + _syncTemplateRolesToDb() + + logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI, {len(RESOURCE_OBJECTS)} resource, {len(DATA_OBJECTS)} data objects") return True + except Exception as e: - logging.getLogger(__name__).error(f"Failed to register feature '{FEATURE_CODE}': {e}") + logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}") return False + + +def _syncTemplateRolesToDb() -> int: + """ + Sync template roles and their AccessRules to the database. + Creates global template roles (mandateId=None) if they don't exist. + + Returns: + Number of roles created/updated + """ + try: + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext + + rootInterface = getRootInterface() + db = rootInterface.db + + # Get existing template roles for this feature + existingRoles = db.getRecordset( + Role, + recordFilter={"featureCode": FEATURE_CODE, "mandateId": None} + ) + existingRoleLabels = {r.get("roleLabel"): r.get("id") for r in existingRoles} + + createdCount = 0 + for roleTemplate in TEMPLATE_ROLES: + roleLabel = roleTemplate["roleLabel"] + + if roleLabel in existingRoleLabels: + roleId = existingRoleLabels[roleLabel] + logger.debug(f"Template role '{roleLabel}' already exists with ID {roleId}") + + # Ensure AccessRules exist for this role + _ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", [])) + else: + # Create new template role + newRole = Role( + roleLabel=roleLabel, + description=roleTemplate.get("description", {}), + featureCode=FEATURE_CODE, + mandateId=None, # Global template + featureInstanceId=None, + isSystemRole=False + ) + createdRole = db.recordCreate(Role, newRole.model_dump()) + roleId = createdRole.get("id") + + # Create AccessRules for this role + _ensureAccessRulesForRole(db, roleId, roleTemplate.get("accessRules", [])) + + logger.info(f"Created template role '{roleLabel}' with ID {roleId}") + createdCount += 1 + + if createdCount > 0: + logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles") + + # Repair instance-specific roles that are missing AccessRules + _repairInstanceRolesAccessRules(db, existingRoleLabels) + + return createdCount + + except Exception as e: + logger.error(f"Error syncing template roles for feature '{FEATURE_CODE}': {e}") + return 0 + + +def _repairInstanceRolesAccessRules(db, templateRoleLabels: Dict[str, str]) -> int: + """ + Repair instance-specific roles by copying AccessRules from their template roles. + This ensures instance roles created before AccessRules were defined get updated. + + Args: + db: Database connector + templateRoleLabels: Dict mapping roleLabel to template role ID + + Returns: + Number of instance roles repaired + """ + from modules.datamodels.datamodelRbac import Role, AccessRule, AccessRuleContext + + repairedCount = 0 + + # Get all instance-specific roles for this feature (mandateId is NOT None) + allRoles = db.getRecordset(Role, recordFilter={"featureCode": FEATURE_CODE}) + instanceRoles = [r for r in allRoles if r.get("mandateId") is not None] + + for instanceRole in instanceRoles: + roleLabel = instanceRole.get("roleLabel") + instanceRoleId = instanceRole.get("id") + + # Find matching template role + templateRoleId = templateRoleLabels.get(roleLabel) + if not templateRoleId: + continue + + # Check if instance role has AccessRules + existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": instanceRoleId}) + if existingRules: + continue # Already has rules, skip + + # Copy AccessRules from template role + templateRules = db.getRecordset(AccessRule, recordFilter={"roleId": templateRoleId}) + if not templateRules: + continue # Template has no rules + + for rule in templateRules: + newRule = AccessRule( + roleId=instanceRoleId, + context=rule.get("context"), + item=rule.get("item"), + view=rule.get("view", False), + read=rule.get("read"), + create=rule.get("create"), + update=rule.get("update"), + delete=rule.get("delete"), + ) + db.recordCreate(AccessRule, newRule.model_dump()) + + logger.info(f"Repaired instance role '{roleLabel}' (ID: {instanceRoleId}): copied {len(templateRules)} AccessRules from template") + repairedCount += 1 + + if repairedCount > 0: + logger.info(f"Feature '{FEATURE_CODE}': Repaired {repairedCount} instance roles with missing AccessRules") + + return repairedCount + + +def _ensureAccessRulesForRole(db, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int: + """ + Ensure AccessRules exist for a role based on templates. + + Args: + db: Database connector + roleId: Role ID + ruleTemplates: List of rule templates + + Returns: + Number of rules created + """ + from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext + + # Get existing rules for this role + existingRules = db.getRecordset(AccessRule, recordFilter={"roleId": roleId}) + + # Create a set of existing rule signatures to avoid duplicates + existingSignatures = set() + for rule in existingRules: + sig = (rule.get("context"), rule.get("item")) + existingSignatures.add(sig) + + createdCount = 0 + for template in ruleTemplates: + context = template.get("context", "UI") + item = template.get("item") + sig = (context, item) + + if sig in existingSignatures: + continue + + # Map context string to enum + if context == "UI": + contextEnum = AccessRuleContext.UI + elif context == "DATA": + contextEnum = AccessRuleContext.DATA + elif context == "RESOURCE": + contextEnum = AccessRuleContext.RESOURCE + else: + contextEnum = context + + newRule = AccessRule( + roleId=roleId, + context=contextEnum, + item=item, + view=template.get("view", False), + read=template.get("read"), + create=template.get("create"), + update=template.get("update"), + delete=template.get("delete"), + ) + db.recordCreate(AccessRule, newRule.model_dump()) + createdCount += 1 + + if createdCount > 0: + logger.debug(f"Created {createdCount} AccessRules for role {roleId}") + + return createdCount import json import uuid import asyncio @@ -139,6 +390,7 @@ from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, Operati from modules.datamodels.datamodelDocref import DocumentReferenceList, DocumentItemReference from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp from modules.services import getInterface as getServices +from modules.features.chatbot import interfaceFeatureChatbot from modules.features.chatbot.eventManager import get_event_manager from modules.workflows.methods.methodAi.methodAi import MethodAi from modules.connectors.connectorPreprocessor import PreprocessorConnector @@ -146,7 +398,8 @@ from modules.features.chatbot.chatbotConstants import ( get_initial_analysis_prompt, generate_conversation_name, get_final_answer_system_prompt, - get_final_answer_prompt_with_results + get_final_answer_prompt_with_results, + get_empty_results_retry_instructions ) import base64 @@ -184,7 +437,8 @@ async def chatProcess( currentUser: User, mandateId: str, userInput: UserInputRequest, - workflowId: Optional[str] = None + workflowId: Optional[str] = None, + featureInstanceId: Optional[str] = None ) -> ChatWorkflow: """ Simple chatbot processing - analyze user input and generate queries. @@ -200,14 +454,25 @@ async def chatProcess( mandateId: Mandate context (from RequestContext / X-Mandate-Id header) userInput: User input request workflowId: Optional workflow ID to continue existing conversation + featureInstanceId: Optional feature instance ID for instance-level isolation Returns: ChatWorkflow instance """ try: - # Get services with mandate context + # Get services normally (for other services like chat, ai, etc.) services = getServices(currentUser, None, mandateId=mandateId) - interfaceDbChat = services.interfaceDbChat + + # Replace interfaceDbChat with chatbot-specific interface that supports featureInstanceId + # This ensures instance-level data isolation + interfaceDbChat = interfaceFeatureChatbot.getInterface( + currentUser, + mandateId=mandateId, + featureInstanceId=featureInstanceId + ) + + # Update services to use the chatbot-specific interface + services.interfaceDbChat = interfaceDbChat # Get event manager and create queue if needed event_manager = get_event_manager() @@ -218,6 +483,10 @@ async def chatProcess( if not workflow: raise ValueError(f"Workflow {workflowId} not found") + # Verify workflow belongs to this instance if instanceId is provided + if featureInstanceId and workflow.featureInstanceId != featureInstanceId: + raise ValueError(f"Workflow {workflowId} does not belong to instance '{featureInstanceId}'") + # Resume workflow: increment round number new_round = workflow.currentRound + 1 interfaceDbChat.updateWorkflow(workflowId, { @@ -243,6 +512,7 @@ async def chatProcess( workflowData = { "id": str(uuid.uuid4()), "mandateId": mandateId, + "featureInstanceId": featureInstanceId, "status": "running", "name": conversation_name, "currentRound": 1, @@ -261,7 +531,10 @@ async def chatProcess( event_manager.create_queue(workflow.id) # Reload workflow to get current message count - workflow = interfaceDbChat.getWorkflow(workflow.id) + workflow_id = workflow.id + workflow = interfaceDbChat.getWorkflow(workflow_id) + if not workflow: + raise ValueError(f"Failed to reload workflow {workflow_id}") # Process uploaded files and create ChatDocuments user_documents = [] @@ -297,13 +570,15 @@ async def chatProcess( logger.error(f"Error processing file ID {fileId}: {e}", exc_info=True) # Store user message + # Get message count safely (workflow.messages might be None or empty) + message_count = len(workflow.messages) if workflow.messages else 0 userMessageData = { "id": f"msg_{uuid.uuid4()}", "workflowId": workflow.id, "message": userInput.prompt, "role": "user", "status": "first" if workflowId is None else "step", - "sequenceNr": len(workflow.messages) + 1, + "sequenceNr": message_count + 1, "publishedAt": getUtcTimestamp(), "roundNumber": workflow.currentRound, "taskNumber": 0, @@ -1046,6 +1321,11 @@ async def _processChatbotMessage( "simpleMode": True }) + # Check if workflow was stopped during analysis + if await _check_workflow_stopped(interfaceDbChat, workflowId): + logger.info(f"Workflow {workflowId} was stopped during analysis, aborting processing") + return + # Extract content from ActionResult analysis_content = None if analysis_result.success and analysis_result.documents: @@ -1618,8 +1898,18 @@ async def _processChatbotMessage( ) ) + # Double-check workflow wasn't stopped right before AI call + if await _check_workflow_stopped(interfaceDbChat, workflowId): + logger.info(f"Workflow {workflowId} was stopped before final answer AI call, aborting") + return + answerResponse = await services.ai.callAi(answerRequest) + # Check immediately after AI call completes - if stopped, abort without processing or storing + if await _check_workflow_stopped(interfaceDbChat, workflowId): + logger.info(f"Workflow {workflowId} was stopped during final answer AI call, aborting without storing message") + return + # Check for errors in AI response if answerResponse.errorCount > 0: logger.error(f"AI call failed with errorCount={answerResponse.errorCount}: {answerResponse.content}") @@ -1628,9 +1918,9 @@ async def _processChatbotMessage( finalAnswer = answerResponse.content logger.info("Final answer generated") - # Check if workflow was stopped during AI call - if so, don't store the message + # Check again after generating answer (in case it was stopped while generating) if await _check_workflow_stopped(interfaceDbChat, workflowId): - logger.info(f"Workflow {workflowId} was stopped during final answer generation, not storing message") + logger.info(f"Workflow {workflowId} was stopped after final answer generation, not storing message") return # Reload workflow to get current message count diff --git a/modules/features/chatbot/routeFeatureChatbot.py b/modules/features/chatbot/routeFeatureChatbot.py index ee05e2ac..e6e9c626 100644 --- a/modules/features/chatbot/routeFeatureChatbot.py +++ b/modules/features/chatbot/routeFeatureChatbot.py @@ -9,10 +9,11 @@ import logging import json import asyncio import math +import uuid from typing import Optional, Any, Dict, Union from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, status from fastapi.responses import StreamingResponse -from modules.shared.timeUtils import parseTimestamp +from modules.shared.timeUtils import parseTimestamp, getUtcTimestamp # Import auth modules from modules.auth import limiter, getRequestContext, RequestContext @@ -20,10 +21,12 @@ from modules.auth import limiter, getRequestContext, RequestContext # Import interfaces from . import interfaceFeatureChatbot as interfaceDbChat from modules.interfaces.interfaceRbac import getRecordsetWithRBAC +from modules.interfaces.interfaceDbApp import getRootInterface +from modules.interfaces.interfaceFeatures import getFeatureInterface # Import models from .datamodelFeatureChatbot import ChatWorkflow, UserInputRequest, WorkflowModeEnum -from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse +from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata # Import chatbot feature from . import chatProcess @@ -42,14 +45,71 @@ router = APIRouter( responses={404: {"description": "Not found"}} ) -def _getServiceChat(context: RequestContext): - return interfaceDbChat.getInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None) +def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None): + """Get chatbot interface with instance context.""" + mandateId = str(context.mandateId) if context.mandateId else None + return interfaceDbChat.getInterface( + context.user, + mandateId=mandateId, + featureInstanceId=instanceId + ) + + +async def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: + """ + Validate that the user has access to the feature instance. + Returns the mandateId for the instance. + + Args: + instanceId: The FeatureInstance ID from URL + context: The request context with user info + + Returns: + mandateId of the instance + + Raises: + HTTPException 404 if instance not found + HTTPException 403 if user doesn't have access + """ + rootInterface = getRootInterface() + featureInterface = getFeatureInterface(rootInterface.db) + + instance = featureInterface.getFeatureInstance(instanceId) + if not instance: + raise HTTPException( + status_code=404, + detail=f"Feature instance '{instanceId}' not found" + ) + + # Verify it's a chatbot instance + if instance.featureCode != "chatbot": + raise HTTPException( + status_code=400, + detail=f"Instance '{instanceId}' is not a chatbot instance" + ) + + # Verify user has access to this instance + if not context.isSysAdmin: + # Check if user has FeatureAccess for this instance + featureAccesses = rootInterface.getFeatureAccessesForUser(str(context.user.id)) + hasAccess = any( + str(fa.featureInstanceId) == instanceId and fa.enabled + for fa in featureAccesses + ) + if not hasAccess: + raise HTTPException( + status_code=403, + detail=f"Access denied to feature instance '{instanceId}'" + ) + + return str(instance.mandateId) # Chatbot streaming endpoint (SSE) -@router.post("/start/stream") +@router.post("/{instanceId}/start/stream") @limiter.limit("120/minute") async def stream_chatbot_start( request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue (can also be in request body)"), userInput: UserInputRequest = Body(...), context: RequestContext = Depends(getRequestContext) @@ -59,10 +119,13 @@ async def stream_chatbot_start( Streams progress updates in real-time via Server-Sent Events. workflowId can be provided either: - - As a query parameter: /api/chatbot/start/stream?workflowId=xxx + - As a query parameter: /api/chatbot/{instanceId}/start/stream?workflowId=xxx - In the request body as part of UserInputRequest - Query parameter takes precedence if both are provided """ + # Validate instance access + mandateId = await _validateInstanceAccess(instanceId, context) + event_manager = get_event_manager() try: @@ -70,7 +133,15 @@ async def stream_chatbot_start( final_workflow_id = workflowId or userInput.workflowId # Start background processing (this will create the workflow and event queue) - workflow = await chatProcess(context.user, str(context.mandateId), userInput, final_workflow_id) + # Pass featureInstanceId to chatProcess + workflow = await chatProcess(context.user, mandateId, userInput, final_workflow_id, featureInstanceId=instanceId) + + # Check if workflow was created successfully + if not workflow: + raise HTTPException( + status_code=500, + detail="Failed to create or load workflow" + ) # Get event queue for the workflow queue = event_manager.get_queue(workflow.id) @@ -82,7 +153,7 @@ async def stream_chatbot_start( """Async generator for SSE events - pure event-driven streaming (no polling).""" try: # Get interface for initial data and status checks - interfaceDbChat = _getServiceChat(context) + interfaceDbChat = _getServiceChat(context, instanceId) # Get current workflow to check if resuming and get current round current_workflow = interfaceDbChat.getWorkflow(workflow.id) @@ -233,16 +304,56 @@ async def stream_chatbot_start( # Workflow stop endpoint -@router.post("/{workflowId}/stop", response_model=ChatWorkflow) +@router.post("/{instanceId}/stop/{workflowId}", response_model=ChatWorkflow) @limiter.limit("120/minute") async def stop_chatbot( request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), workflowId: str = Path(..., description="ID of the workflow to stop"), context: RequestContext = Depends(getRequestContext) ) -> ChatWorkflow: """Stops a running chatbot workflow.""" + # Validate instance access + await _validateInstanceAccess(instanceId, context) + try: - workflow = await chatStop(context.user, workflowId) + # Get chatbot interface with instance context + interfaceDbChat = _getServiceChat(context, instanceId) + + # Get workflow to verify it exists and belongs to this instance + workflow = interfaceDbChat.getWorkflow(workflowId) + if not workflow: + raise HTTPException( + status_code=404, + detail=f"Workflow {workflowId} not found" + ) + + # Verify workflow belongs to this instance + if workflow.featureInstanceId and workflow.featureInstanceId != instanceId: + raise HTTPException( + status_code=403, + detail=f"Workflow {workflowId} does not belong to instance {instanceId}" + ) + + # Update workflow status to stopped + interfaceDbChat.updateWorkflow(workflowId, { + "status": "stopped", + "lastActivity": getUtcTimestamp() + }) + + # Store log entry + interfaceDbChat.createLog({ + "id": f"log_{uuid.uuid4()}", + "workflowId": workflowId, + "message": "Workflow stopped by user", + "type": "warning", + "status": "stopped", + "timestamp": getUtcTimestamp(), + "roundNumber": workflow.currentRound if workflow else 1 + }) + + # Reload workflow to return updated version + workflow = interfaceDbChat.getWorkflow(workflowId) # Emit stopped event to active streams event_manager = get_event_manager() @@ -254,29 +365,35 @@ async def stop_chatbot( message="Workflow stopped by user", step="stopped" ) - logger.info(f"Emitted stopped event for workflow {workflowId}") + logger.info(f"Stopped workflow {workflowId} and emitted stopped event") return workflow + except HTTPException: + raise except Exception as e: - logger.error(f"Error in stop_chatbot: {str(e)}") + logger.error(f"Error in stop_chatbot: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=str(e) ) # Delete chatbot workflow endpoint -@router.delete("/{workflowId}", response_model=Dict[str, Any]) +@router.delete("/{instanceId}/{workflowId}", response_model=Dict[str, Any]) @limiter.limit("120/minute") async def delete_chatbot( request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), workflowId: str = Path(..., description="ID of the workflow to delete"), context: RequestContext = Depends(getRequestContext) ) -> Dict[str, Any]: """Deletes a chatbot workflow and its associated data.""" + # Validate instance access + mandateId = await _validateInstanceAccess(instanceId, context) + try: # Get service center - interfaceDbChat = _getServiceChat(context) + interfaceDbChat = _getServiceChat(context, instanceId) # Check workflow access and permission using RBAC workflows = getRecordsetWithRBAC( @@ -300,6 +417,14 @@ async def delete_chatbot( detail=f"Workflow {workflowId} is not a chatbot workflow" ) + # Verify workflow belongs to this instance + workflow_instance_id = workflow_data.get("featureInstanceId") + if workflow_instance_id != instanceId: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Workflow {workflowId} does not belong to instance '{instanceId}'" + ) + # Check if user has permission to delete using RBAC if not interfaceDbChat.checkRbacPermission(ChatWorkflow, "delete", workflowId): raise HTTPException( @@ -330,10 +455,11 @@ async def delete_chatbot( ) # List chatbot threads/workflows or get specific thread details -@router.get("/threads") +@router.get("/{instanceId}/threads") @limiter.limit("120/minute") async def get_chatbot_threads( request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), workflowId: Optional[str] = Query(None, description="Optional workflow ID to get details and chat data for a specific thread"), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object (only used when workflowId is not provided)"), context: RequestContext = Depends(getRequestContext) @@ -344,8 +470,11 @@ async def get_chatbot_threads( - If workflowId is provided: Returns the workflow details and all chat data (messages, logs, stats) - If workflowId is not provided: Returns a paginated list of all workflows """ + # Validate instance access + mandateId = await _validateInstanceAccess(instanceId, context) + try: - interfaceDbChat = _getServiceChat(context) + interfaceDbChat = _getServiceChat(context, instanceId) # If workflowId is provided, return single workflow with chat data if workflowId: