# Copyright (c) 2025 Patrick Motsch # All rights reserved. """ Chatbot routes for the backend API. Implements simple chatbot endpoints using direct AI center calls via chatbot feature. """ import logging import json import asyncio import math from typing import Optional, Any, Dict, Union from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, status from fastapi.responses import StreamingResponse from modules.shared.timeUtils import parseTimestamp # Import auth modules from modules.auth import limiter, getCurrentUser # Import interfaces import modules.interfaces.interfaceDbChatObjects as interfaceDbChatObjects from modules.interfaces.interfaceRbac import getRecordsetWithRBAC # Import models from modules.datamodels.datamodelChat import ChatWorkflow, UserInputRequest, WorkflowModeEnum from modules.datamodels.datamodelUam import User from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse # Import chatbot feature from modules.features.chatbot import chatProcess from modules.features.chatbot.eventManager import get_event_manager # Import workflow control functions from modules.features.workflow import chatStop # Configure logger logger = logging.getLogger(__name__) # Create router for chatbot endpoints router = APIRouter( prefix="/api/chatbot", tags=["Chatbot"], responses={404: {"description": "Not found"}} ) def getServiceChat(currentUser: User): return interfaceDbChatObjects.getInterface(currentUser) # Chatbot streaming endpoint (SSE) @router.post("/start/stream") @limiter.limit("120/minute") async def stream_chatbot_start( request: Request, workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue (can also be in request body)"), userInput: UserInputRequest = Body(...), currentUser: User = Depends(getCurrentUser) ) -> StreamingResponse: """ Starts a new chatbot workflow or continues an existing one with SSE streaming. Streams progress updates in real-time via Server-Sent Events. workflowId can be provided either: - As a query parameter: /api/chatbot/start/stream?workflowId=xxx - In the request body as part of UserInputRequest - Query parameter takes precedence if both are provided """ event_manager = get_event_manager() try: # Use workflowId from query parameter if provided, otherwise from request body final_workflow_id = workflowId or userInput.workflowId # Start background processing (this will create the workflow and event queue) workflow = await chatProcess(currentUser, userInput, final_workflow_id) # Get event queue for the workflow queue = event_manager.get_queue(workflow.id) if not queue: # Create queue if it doesn't exist queue = event_manager.create_queue(workflow.id) async def event_stream(): """Async generator for SSE events - pure event-driven streaming (no polling).""" try: # Get interface for initial data and status checks interfaceDbChat = getServiceChat(currentUser) # Get current workflow to check if resuming and get current round current_workflow = interfaceDbChat.getWorkflow(workflow.id) current_round = current_workflow.currentRound if current_workflow else None is_resuming = final_workflow_id is not None and current_round and current_round > 1 # Send initial chat data (exact format as chatData endpoint) - only once at start try: chatData = interfaceDbChat.getUnifiedChatData(workflow.id, None) if chatData.get("items"): # Filter items by round number if resuming filtered_items = [] for item in chatData["items"]: if is_resuming and current_round: # Get round number from item item_round = None item_data = item.get("item") if item_data: # Handle both dict and object access if isinstance(item_data, dict): item_round = item_data.get("roundNumber") elif hasattr(item_data, "roundNumber"): item_round = item_data.roundNumber # When resuming, only include items from current round onwards # Exclude items without roundNumber (they're from old rounds before roundNumber was added) # Exclude items with roundNumber < current_round (from previous rounds) if item_round is None or item_round < current_round: continue # Skip items from previous rounds or without round info filtered_items.append(item) # Emit filtered items for item in filtered_items: # Convert Pydantic models to dicts for JSON serialization serializable_item = { "type": item.get("type"), "createdAt": item.get("createdAt"), "item": item.get("item").dict() if hasattr(item.get("item"), "dict") else item.get("item") } # Emit item directly in exact chatData format: {type, createdAt, item} yield f"data: {json.dumps(serializable_item)}\n\n" except Exception as e: logger.warning(f"Error fetching initial chat data: {e}") # Keepalive interval (30 seconds) keepalive_interval = 30.0 last_keepalive = asyncio.get_event_loop().time() # Status check interval (check workflow status every 5 seconds - less frequent since we're event-driven) status_check_interval = 5.0 last_status_check = asyncio.get_event_loop().time() # Stream events until completion or timeout - pure event-driven (no polling) timeout = 300.0 # 5 minutes max start_time = asyncio.get_event_loop().time() while True: # Check timeout elapsed = asyncio.get_event_loop().time() - start_time if elapsed > timeout: logger.info(f"Stream timeout for workflow {workflow.id}") break # Check for client disconnection if await request.is_disconnected(): logger.info(f"Client disconnected for workflow {workflow.id}") break current_time = asyncio.get_event_loop().time() # Periodically check workflow status (less frequent since we're event-driven) if current_time - last_status_check >= status_check_interval: try: current_workflow = interfaceDbChat.getWorkflow(workflow.id) if current_workflow and current_workflow.status == "stopped": logger.info(f"Workflow {workflow.id} was stopped, closing stream") break except Exception as e: logger.warning(f"Error checking workflow status: {e}") last_status_check = current_time # Get event from queue (pure event-driven - no polling database) try: event = await asyncio.wait_for(queue.get(), timeout=1.0) # Handle different event types event_type = event.get("type") event_data = event.get("data", {}) # Emit chatdata events (messages, logs, stats) in exact chatData format if event_type == "chatdata" and event_data: # Emit item directly in exact chatData format: {type, createdAt, item} chatdata_item = event_data # Ensure item field is serializable (convert Pydantic models to dicts) if isinstance(chatdata_item, dict) and "item" in chatdata_item: item_obj = chatdata_item.get("item") if hasattr(item_obj, "dict"): chatdata_item = chatdata_item.copy() chatdata_item["item"] = item_obj.dict() yield f"data: {json.dumps(chatdata_item)}\n\n" # Handle completion/stopped events to close stream elif event_type == "complete": logger.info(f"Workflow {workflow.id} completed, closing stream") break elif event_type == "stopped": logger.info(f"Workflow {workflow.id} stopped, closing stream") break elif event_type == "error" and event.get("step") == "error": logger.warning(f"Workflow {workflow.id} error, closing stream") break last_keepalive = current_time except asyncio.TimeoutError: # Send keepalive if needed (no events received, but keep connection alive) current_time = asyncio.get_event_loop().time() if current_time - last_keepalive >= keepalive_interval: yield f": keepalive\n\n" last_keepalive = current_time continue except Exception as e: logger.error(f"Error in event stream: {e}") break except Exception as e: logger.error(f"Error in event stream generator: {e}", exc_info=True) finally: # Stream ends - cleanup handled by event manager pass return StreamingResponse( event_stream(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" # Disable buffering for nginx } ) except Exception as e: logger.error(f"Error in stream_chatbot_start: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=str(e) ) # Workflow stop endpoint @router.post("/{workflowId}/stop", response_model=ChatWorkflow) @limiter.limit("120/minute") async def stop_chatbot( request: Request, workflowId: str = Path(..., description="ID of the workflow to stop"), currentUser: User = Depends(getCurrentUser) ) -> ChatWorkflow: """Stops a running chatbot workflow.""" try: workflow = await chatStop(currentUser, workflowId) # Emit stopped event to active streams event_manager = get_event_manager() await event_manager.emit_event( context_id=workflowId, event_type="stopped", data={"workflowId": workflowId}, event_category="workflow", message="Workflow stopped by user", step="stopped" ) logger.info(f"Emitted stopped event for workflow {workflowId}") return workflow except Exception as e: logger.error(f"Error in stop_chatbot: {str(e)}") raise HTTPException( status_code=500, detail=str(e) ) # Delete chatbot workflow endpoint @router.delete("/{workflowId}", response_model=Dict[str, Any]) @limiter.limit("120/minute") async def delete_chatbot( request: Request, workflowId: str = Path(..., description="ID of the workflow to delete"), currentUser: User = Depends(getCurrentUser) ) -> Dict[str, Any]: """Deletes a chatbot workflow and its associated data.""" try: # Get service center interfaceDbChat = getServiceChat(currentUser) # Check workflow access and permission using RBAC workflows = getRecordsetWithRBAC( interfaceDbChat.db, ChatWorkflow, currentUser, recordFilter={"id": workflowId} ) if not workflows: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Workflow with ID {workflowId} not found" ) workflow_data = workflows[0] # Check if workflow is a chatbot workflow if workflow_data.get("workflowMode") != WorkflowModeEnum.WORKFLOW_CHATBOT.value: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Workflow {workflowId} is not a chatbot workflow" ) # Check if user has permission to delete using RBAC if not interfaceDbChat.checkRbacPermission(ChatWorkflow, "delete", workflowId): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to delete this workflow" ) # Delete workflow success = interfaceDbChat.deleteWorkflow(workflowId) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete workflow" ) return { "id": workflowId, "message": "Chatbot workflow and associated data deleted successfully" } except HTTPException: raise except Exception as e: logger.error(f"Error in delete_chatbot: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Error deleting chatbot workflow: {str(e)}" ) # List chatbot threads/workflows or get specific thread details @router.get("/threads") @limiter.limit("120/minute") async def get_chatbot_threads( request: Request, workflowId: Optional[str] = Query(None, description="Optional workflow ID to get details and chat data for a specific thread"), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object (only used when workflowId is not provided)"), currentUser: User = Depends(getCurrentUser) ) -> Union[PaginatedResponse[ChatWorkflow], Dict[str, Any]]: """ List all chatbot workflows (threads) for the current user, or get details and chat data for a specific thread. - If workflowId is provided: Returns the workflow details and all chat data (messages, logs, stats) - If workflowId is not provided: Returns a paginated list of all workflows """ try: interfaceDbChat = getServiceChat(currentUser) # If workflowId is provided, return single workflow with chat data if workflowId: workflow = interfaceDbChat.getWorkflow(workflowId) if not workflow: raise HTTPException( status_code=404, detail=f"Workflow with ID {workflowId} not found" ) # Normalize workflow data to match ChatWorkflow model requirements # Convert workflow object to dict if needed, and normalize None values if hasattr(workflow, 'model_dump'): workflow_dict = workflow.model_dump() elif hasattr(workflow, 'dict'): workflow_dict = workflow.dict() elif isinstance(workflow, dict): workflow_dict = dict(workflow) else: workflow_dict = workflow # Set maxSteps to default value of 10 if None (as per ChatWorkflow model) if workflow_dict.get("maxSteps") is None: workflow_dict["maxSteps"] = 10 # Get unified chat data for this workflow chatData = interfaceDbChat.getUnifiedChatData(workflowId, None) return { "workflow": workflow_dict, "chatData": chatData } # Otherwise, return paginated list of workflows # Parse pagination parameter paginationParams = None if pagination: try: paginationDict = json.loads(pagination) paginationParams = PaginationParams(**paginationDict) if paginationDict else None except (json.JSONDecodeError, ValueError) as e: raise HTTPException( status_code=400, detail=f"Invalid pagination parameter: {str(e)}" ) # Get all workflows filtered by mandateId (RBAC handles this automatically) # We get all workflows first to filter by workflowMode before pagination all_workflows = interfaceDbChat.getWorkflows(pagination=None) # Filter to only include chatbot workflows chatbot_workflows_data = [ wf for wf in all_workflows if wf.get("workflowMode") == WorkflowModeEnum.WORKFLOW_CHATBOT.value ] # Apply pagination if requested if paginationParams: # Apply sorting if provided if paginationParams.sort: chatbot_workflows_data = interfaceDbChat._applySorting(chatbot_workflows_data, paginationParams.sort) # Count total items after filtering totalItems = len(chatbot_workflows_data) totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0 # Apply pagination (skip/limit) startIdx = (paginationParams.page - 1) * paginationParams.pageSize endIdx = startIdx + paginationParams.pageSize workflows = chatbot_workflows_data[startIdx:endIdx] else: workflows = chatbot_workflows_data totalItems = len(chatbot_workflows_data) totalPages = 1 # Normalize workflow data to match ChatWorkflow model requirements # Convert None values to defaults before response validation normalized_workflows = [] for wf in workflows: normalized_wf = dict(wf) # Create a copy # Set maxSteps to default value of 10 if None (as per ChatWorkflow model) if normalized_wf.get("maxSteps") is None: normalized_wf["maxSteps"] = 10 normalized_workflows.append(normalized_wf) # Create paginated response from modules.datamodels.datamodelPagination import PaginationMetadata metadata = PaginationMetadata( currentPage=paginationParams.page if paginationParams else 1, pageSize=paginationParams.pageSize if paginationParams else len(workflows), totalItems=totalItems, totalPages=totalPages, sort=paginationParams.sort if paginationParams else [], filters=paginationParams.filters if paginationParams else None ) return PaginatedResponse( items=normalized_workflows, pagination=metadata ) except HTTPException: raise except Exception as e: logger.error(f"Error getting chatbot threads: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Error getting chatbot threads: {str(e)}" )