# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""
Chatbot routes for the backend API.

Implements simple chatbot endpoints using direct AI center calls via the chatbot feature.
"""

import logging
import json
import asyncio
import math
from typing import Optional, Any, Dict, Union
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, status
from fastapi.responses import StreamingResponse
from modules.shared.timeUtils import parseTimestamp

# Import auth modules
from modules.auth import limiter, getCurrentUser

# Import interfaces
import modules.interfaces.interfaceDbChatObjects as interfaceDbChatObjects
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC

# Import models
from modules.datamodels.datamodelChat import ChatWorkflow, UserInputRequest, WorkflowModeEnum
from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse

# Import chatbot feature
from modules.features.chatbot import chatProcess
from modules.features.chatbot.eventManager import get_event_manager

# Import workflow control functions
from modules.features.workflow import chatStop

# Configure logger
logger = logging.getLogger(__name__)

# Create router for chatbot endpoints
router = APIRouter(
    prefix="/api/chatbot",
    tags=["Chatbot"],
    responses={404: {"description": "Not found"}}
)


def getServiceChat(currentUser: User):
    """Return the chat objects DB interface for the given user."""
    return interfaceDbChatObjects.getInterface(currentUser)


# Chatbot streaming endpoint (SSE)
@router.post("/start/stream")
@limiter.limit("120/minute")
async def stream_chatbot_start(
    request: Request,
    workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue (can also be in request body)"),
    userInput: UserInputRequest = Body(...),
    currentUser: User = Depends(getCurrentUser)
) -> StreamingResponse:
    """
    Starts a new chatbot workflow or continues an existing one with SSE streaming.
    Streams progress updates in real-time via Server-Sent Events.

    workflowId can be provided either:
    - as a query parameter: /api/chatbot/start/stream?workflowId=xxx
    - in the request body as part of UserInputRequest

    The query parameter takes precedence if both are provided.
    """
    event_manager = get_event_manager()

    try:
        # Use workflowId from query parameter if provided, otherwise from request body
        final_workflow_id = workflowId or userInput.workflowId

        # Start background processing (this will create the workflow and event queue)
        workflow = await chatProcess(currentUser, userInput, final_workflow_id)

        # Get event queue for the workflow
        queue = event_manager.get_queue(workflow.id)
        if not queue:
            # Create queue if it doesn't exist
            queue = event_manager.create_queue(workflow.id)

        async def event_stream():
            """Async generator for SSE events."""
            try:
                # Get interface for status checks and chat data
                interfaceDbChat = getServiceChat(currentUser)

                # Get current workflow to check if resuming and get current round
                current_workflow = interfaceDbChat.getWorkflow(workflow.id)
                current_round = current_workflow.currentRound if current_workflow else None
                is_resuming = final_workflow_id is not None and current_round and current_round > 1

                # Send initial chat data (exact format as chatData endpoint)
                try:
                    chatData = interfaceDbChat.getUnifiedChatData(workflow.id, None)
                    if chatData.get("items"):
                        # Filter items by round number if resuming
                        filtered_items = []
                        for item in chatData["items"]:
                            if is_resuming and current_round:
                                # Get round number from item
                                item_round = None
                                item_data = item.get("item")
                                if item_data:
                                    # Handle both dict and object access
                                    if isinstance(item_data, dict):
                                        item_round = item_data.get("roundNumber")
                                    elif hasattr(item_data, "roundNumber"):
                                        item_round = item_data.roundNumber

                                # When resuming, only include items from current round onwards
                                # Exclude items without roundNumber (they're from old rounds before roundNumber was added)
                                # Exclude items with roundNumber < current_round (from previous rounds)
                                if item_round is None or item_round < current_round:
                                    continue  # Skip items from previous rounds or without round info

                            filtered_items.append(item)

                        # Emit filtered items
                        for item in filtered_items:
                            # Convert Pydantic models to dicts for JSON serialization
                            serializable_item = {
                                "type": item.get("type"),
                                "createdAt": item.get("createdAt"),
                                "item": item.get("item").dict() if hasattr(item.get("item"), "dict") else item.get("item")
                            }
                            # Emit item directly in exact chatData format: {type, createdAt, item}
                            yield f"data: {json.dumps(serializable_item)}\n\n"

                        # Set initial timestamp for incremental fetching
                        if filtered_items:
                            timestamps = [parseTimestamp(item.get("createdAt"), default=0) for item in filtered_items]
                            last_chatdata_timestamp = max(timestamps) if timestamps else None
                        else:
                            last_chatdata_timestamp = None
                    else:
                        last_chatdata_timestamp = None
                except Exception as e:
                    logger.warning(f"Error fetching initial chat data: {e}")
                    last_chatdata_timestamp = None

                # Keepalive interval (30 seconds)
                keepalive_interval = 30.0
                last_keepalive = asyncio.get_event_loop().time()

                # Status check interval (check workflow status every 3 seconds)
                status_check_interval = 3.0
                last_status_check = asyncio.get_event_loop().time()

                # Chat data fetch interval (fetch chat data every 0.5 seconds for real-time updates)
                chatdata_fetch_interval = 0.5
                last_chatdata_fetch = asyncio.get_event_loop().time()

                # Stream events until completion or timeout
                timeout = 300.0  # 5 minutes max
                start_time = asyncio.get_event_loop().time()
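
                # The loop below multiplexes several concerns over one connection:
                # timeout and disconnect checks, a 3 s workflow-status poll, a 0.5 s
                # chat-data poll, and a 1 s blocking wait on the event queue that
                # also paces the 30 s keepalive comments.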
                while True:
                    # Check timeout
                    elapsed = asyncio.get_event_loop().time() - start_time
                    if elapsed > timeout:
                        # Timeout - just close stream, don't emit non-chatData format events
                        logger.info(f"Stream timeout for workflow {workflow.id}")
                        break

                    # Check for client disconnection
                    if await request.is_disconnected():
                        logger.info(f"Client disconnected for workflow {workflow.id}")
                        break

                    current_time = asyncio.get_event_loop().time()

                    # Periodically check workflow status
                    if current_time - last_status_check >= status_check_interval:
                        try:
                            current_workflow = interfaceDbChat.getWorkflow(workflow.id)
                            if current_workflow and current_workflow.status == "stopped":
                                logger.info(f"Workflow {workflow.id} was stopped, closing stream")
                                # Don't emit stopped event - just close stream
                                break
                        except Exception as e:
                            logger.warning(f"Error checking workflow status: {e}")
                        last_status_check = current_time

                    # Periodically fetch and emit chat data
                    if current_time - last_chatdata_fetch >= chatdata_fetch_interval:
                        try:
                            chatData = interfaceDbChat.getUnifiedChatData(workflow.id, last_chatdata_timestamp)
                            if chatData.get("items"):
                                # Filter items by round number if resuming
                                filtered_items = []
                                for item in chatData["items"]:
                                    if is_resuming and current_round:
                                        # Get round number from item
                                        item_round = None
                                        item_data = item.get("item")
                                        if item_data:
                                            # Handle both dict and object access
                                            if isinstance(item_data, dict):
                                                item_round = item_data.get("roundNumber")
                                            elif hasattr(item_data, "roundNumber"):
                                                item_round = item_data.roundNumber

                                        # When resuming, only include items from current round onwards
                                        # Exclude items without roundNumber (they're from old rounds before roundNumber was added)
                                        # Exclude items with roundNumber < current_round (from previous rounds)
                                        if item_round is None or item_round < current_round:
                                            continue  # Skip items from previous rounds or without round info

                                    filtered_items.append(item)

                                # Emit filtered items directly in exact chatData format: {type, createdAt, item}
                                for item in filtered_items:
                                    # Convert Pydantic models to dicts for JSON serialization
                                    serializable_item = {
                                        "type": item.get("type"),
                                        "createdAt": item.get("createdAt"),
                                        "item": item.get("item").dict() if hasattr(item.get("item"), "dict") else item.get("item")
                                    }
                                    yield f"data: {json.dumps(serializable_item)}\n\n"

                                # Update timestamp to only get new items next time
                                if chatData["items"]:
                                    # Parse timestamps and get the maximum
                                    timestamps = []
                                    for item in chatData["items"]:
                                        ts = parseTimestamp(item.get("createdAt"), default=0)
                                        timestamps.append(ts)
                                    if timestamps:
                                        last_chatdata_timestamp = max(timestamps)
                        except Exception as e:
                            logger.warning(f"Error fetching chat data: {e}")
                        last_chatdata_fetch = current_time

                    # Try to get event with timeout
                    try:
                        event = await asyncio.wait_for(queue.get(), timeout=1.0)
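
                        # Events from the queue are assumed to be plain dicts; the
                        # handling below relies only on their "type", "data", and
                        # "step" keys, e.g. a hypothetical "chatdata" event:
                        #   {"type": "chatdata", "data": {"type": ..., "createdAt": ..., "item": {...}}}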

                        # Only emit chatdata events (messages, logs, stats) in exact chatData format
                        # Ignore status/progress/complete/stopped/error events that don't match the format
                        if event.get("type") == "chatdata" and event.get("data"):
                            # Emit item directly in exact chatData format: {type, createdAt, item}
                            chatdata_item = event.get("data")
                            # Ensure item field is serializable (convert Pydantic models to dicts)
                            if isinstance(chatdata_item, dict) and "item" in chatdata_item:
                                item_obj = chatdata_item.get("item")
                                if hasattr(item_obj, "dict"):
                                    chatdata_item = chatdata_item.copy()
                                    chatdata_item["item"] = item_obj.dict()
                            yield f"data: {json.dumps(chatdata_item)}\n\n"
                            # Update timestamp for incremental fetching
                            if chatdata_item.get("createdAt"):
                                last_chatdata_timestamp = parseTimestamp(chatdata_item["createdAt"], default=None)

                        # Check if this is a completion/stopped event to close stream
                        if event.get("type") == "complete":
                            logger.info(f"Workflow {workflow.id} completed, closing stream")
                            break
                        elif event.get("type") == "stopped":
                            # Workflow was stopped, close stream
                            logger.info(f"Workflow {workflow.id} stopped, closing stream")
                            break
                        elif event.get("type") == "error" and event.get("step") == "error":
                            # Final error, close stream
                            logger.warning(f"Workflow {workflow.id} error, closing stream")
                            break

                        last_keepalive = asyncio.get_event_loop().time()
                    except asyncio.TimeoutError:
                        # Send keepalive if needed
                        current_time = asyncio.get_event_loop().time()
                        if current_time - last_keepalive >= keepalive_interval:
                            yield ": keepalive\n\n"
                            last_keepalive = current_time
                        continue
                    except Exception as e:
                        logger.error(f"Error in event stream: {e}")
                        yield f"data: {json.dumps({'type': 'error', 'message': f'Stream error: {str(e)}'})}\n\n"
                        break

            except Exception as e:
                logger.error(f"Error in event stream generator: {e}", exc_info=True)
                # Don't emit error events that don't match chatData format
            finally:
                # Stream ends - no final event needed as it doesn't match chatData format
                pass

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"  # Disable buffering for nginx
            }
        )

    except Exception as e:
        logger.error(f"Error in stream_chatbot_start: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )


# Workflow stop endpoint
@router.post("/{workflowId}/stop", response_model=ChatWorkflow)
@limiter.limit("120/minute")
async def stop_chatbot(
    request: Request,
    workflowId: str = Path(..., description="ID of the workflow to stop"),
    currentUser: User = Depends(getCurrentUser)
) -> ChatWorkflow:
    """Stops a running chatbot workflow."""
    try:
        workflow = await chatStop(currentUser, workflowId)

        # Emit stopped event to active streams
        event_manager = get_event_manager()
        await event_manager.emit_event(
            workflowId,
            "stopped",
            "Workflow stopped by user",
            "stopped"
        )
        logger.info(f"Emitted stopped event for workflow {workflowId}")

        return workflow

    except Exception as e:
        logger.error(f"Error in stop_chatbot: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )


# Delete chatbot workflow endpoint
@router.delete("/{workflowId}", response_model=Dict[str, Any])
@limiter.limit("120/minute")
async def delete_chatbot(
    request: Request,
    workflowId: str = Path(..., description="ID of the workflow to delete"),
    currentUser: User = Depends(getCurrentUser)
) -> Dict[str, Any]:
    """Deletes a chatbot workflow and its associated data."""
    try:
        # Get chat DB interface
        interfaceDbChat = getServiceChat(currentUser)

        # Check workflow access and permission using RBAC
        workflows = getRecordsetWithRBAC(
            interfaceDbChat.db,
            ChatWorkflow,
            currentUser,
            recordFilter={"id": workflowId}
        )
        if not workflows:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Workflow with ID {workflowId} not found"
            )

        workflow_data = workflows[0]

        # Check if workflow is a chatbot workflow
        if workflow_data.get("workflowMode") != WorkflowModeEnum.WORKFLOW_CHATBOT.value:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Workflow {workflowId} is not a chatbot workflow"
            )

        # Check if user has permission to delete using RBAC
        if not interfaceDbChat.checkRbacPermission(ChatWorkflow, "delete", workflowId):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You don't have permission to delete this workflow"
            )

        # Delete workflow
        success = interfaceDbChat.deleteWorkflow(workflowId)

        if not success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to delete workflow"
            )

        return {
            "id": workflowId,
            "message": "Chatbot workflow and associated data deleted successfully"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in delete_chatbot: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error deleting chatbot workflow: {str(e)}"
        )


# List chatbot threads/workflows or get specific thread details
@router.get("/threads")
@limiter.limit("120/minute")
async def get_chatbot_threads(
    request: Request,
    workflowId: Optional[str] = Query(None, description="Optional workflow ID to get details and chat data for a specific thread"),
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object (only used when workflowId is not provided)"),
    currentUser: User = Depends(getCurrentUser)
) -> Union[PaginatedResponse[ChatWorkflow], Dict[str, Any]]:
    """
    List all chatbot workflows (threads) for the current user, or get details and chat data for a specific thread.

    - If workflowId is provided: Returns the workflow details and all chat data (messages, logs, stats)
    - If workflowId is not provided: Returns a paginated list of all workflows
    """
    try:
        interfaceDbChat = getServiceChat(currentUser)

        # If workflowId is provided, return single workflow with chat data
        if workflowId:
            workflow = interfaceDbChat.getWorkflow(workflowId)
            if not workflow:
                raise HTTPException(
                    status_code=404,
                    detail=f"Workflow with ID {workflowId} not found"
                )

            # Get unified chat data for this workflow
            chatData = interfaceDbChat.getUnifiedChatData(workflowId, None)

            return {
                "workflow": workflow,
                "chatData": chatData
            }

        # Otherwise, return paginated list of workflows
        # Parse pagination parameter
        paginationParams = None
        if pagination:
            try:
                paginationDict = json.loads(pagination)
                paginationParams = PaginationParams(**paginationDict) if paginationDict else None
            except (json.JSONDecodeError, ValueError) as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid pagination parameter: {str(e)}"
                )

        # Get all workflows filtered by mandateId (RBAC handles this automatically)
        # We get all workflows first to filter by workflowMode before pagination
        all_workflows = interfaceDbChat.getWorkflows(pagination=None)

        # Filter to only include chatbot workflows
        chatbot_workflows_data = [
            wf for wf in all_workflows
            if wf.get("workflowMode") == WorkflowModeEnum.WORKFLOW_CHATBOT.value
        ]

        # Apply pagination if requested
        if paginationParams:
            # Apply sorting if provided
            if paginationParams.sort:
                chatbot_workflows_data = interfaceDbChat._applySorting(chatbot_workflows_data, paginationParams.sort)

            # Count total items after filtering
            totalItems = len(chatbot_workflows_data)
            totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0

            # Apply pagination (skip/limit)
            startIdx = (paginationParams.page - 1) * paginationParams.pageSize
            endIdx = startIdx + paginationParams.pageSize
            workflows_data = chatbot_workflows_data[startIdx:endIdx]
        else:
            workflows_data = chatbot_workflows_data
            totalItems = len(chatbot_workflows_data)
            totalPages = 1

        # Convert raw dictionaries to ChatWorkflow objects
        workflows = []
        for workflow_data in workflows_data:
            try:
                # Load the workflow properly
                workflow = interfaceDbChat.getWorkflow(workflow_data["id"])
                if workflow:
                    workflows.append(workflow)
            except Exception as e:
                logger.warning(f"Error loading workflow {workflow_data.get('id')}: {e}")
                continue

        # Create paginated response
        from modules.datamodels.datamodelPagination import PaginationMetadata
        metadata = PaginationMetadata(
            currentPage=paginationParams.page if paginationParams else 1,
            pageSize=paginationParams.pageSize if paginationParams else len(workflows),
            totalItems=totalItems,
            totalPages=totalPages,
            sort=paginationParams.sort if paginationParams else [],
            filters=paginationParams.filters if paginationParams else None
        )

        return PaginatedResponse(
            items=workflows,
            pagination=metadata
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting chatbot threads: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error getting chatbot threads: {str(e)}"
        )