585 lines
25 KiB
Python
585 lines
25 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Chatbot routes for the backend API.
|
|
Implements simple chatbot endpoints using direct AI center calls via chatbot feature.
|
|
"""
|
|
|
|
import logging
|
|
import json
|
|
import asyncio
|
|
import math
|
|
import uuid
|
|
from typing import Optional, Any, Dict, Union
|
|
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, status
|
|
from fastapi.responses import StreamingResponse
|
|
from modules.shared.timeUtils import parseTimestamp, getUtcTimestamp
|
|
|
|
# Import auth modules
|
|
from modules.auth import limiter, getRequestContext, RequestContext
|
|
|
|
# Import interfaces
|
|
from . import interfaceFeatureChatbot as interfaceDbChat
|
|
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
|
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
|
from modules.interfaces.interfaceFeatures import getFeatureInterface
|
|
|
|
# Import models
|
|
from .datamodelFeatureChatbot import ChatWorkflow, UserInputRequest, WorkflowModeEnum
|
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
|
|
|
# Import chatbot feature
|
|
from . import chatProcess
|
|
from .eventManager import get_event_manager
|
|
|
|
# Import workflow control functions
|
|
from modules.workflows.automation import chatStop
|
|
|
|
# Module-level logger; its name is this module's dotted import path
logger = logging.getLogger(__name__)

# Router that groups every chatbot endpoint under the /api/chatbot prefix
router = APIRouter(prefix="/api/chatbot", tags=["Chatbot"], responses={404: {"description": "Not found"}})
|
|
|
|
def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None):
    """
    Build the chatbot interface for the requesting user.

    Args:
        context: Request context; supplies the user and, if set, the mandate ID
        instanceId: Feature instance ID, forwarded as ``featureInstanceId``

    Returns:
        The object returned by ``interfaceDbChat.getInterface`` for these arguments
    """
    # A falsy mandateId is forwarded as None instead of being stringified
    mandateId = None
    if context.mandateId:
        mandateId = str(context.mandateId)

    return interfaceDbChat.getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
|
|
|
|
|
async def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
    """
    Check that the feature instance exists, is a chatbot instance, and is
    accessible to the requesting user.

    Args:
        instanceId: The FeatureInstance ID taken from the URL
        context: The request context carrying the user and the sysadmin flag

    Returns:
        The instance's mandateId, converted to a string

    Raises:
        HTTPException 404 if the instance does not exist
        HTTPException 400 if the instance's featureCode is not "chatbot"
        HTTPException 403 if a non-sysadmin user has no enabled FeatureAccess for it
    """
    rootInterface = getRootInterface()
    instance = getFeatureInterface(rootInterface.db).getFeatureInstance(instanceId)

    if not instance:
        raise HTTPException(status_code=404, detail=f"Feature instance '{instanceId}' not found")

    if instance.featureCode != "chatbot":
        raise HTTPException(status_code=400, detail=f"Instance '{instanceId}' is not a chatbot instance")

    # Sysadmins skip the per-user FeatureAccess lookup entirely
    if context.isSysAdmin:
        return str(instance.mandateId)

    # Any enabled FeatureAccess record pointing at this instance grants access
    for access in rootInterface.getFeatureAccessesForUser(str(context.user.id)):
        if str(access.featureInstanceId) == instanceId and access.enabled:
            return str(instance.mandateId)

    raise HTTPException(status_code=403, detail=f"Access denied to feature instance '{instanceId}'")
|
|
|
|
# Chatbot streaming endpoint (SSE)
|
|
@router.post("/{instanceId}/start/stream")
@limiter.limit("120/minute")
async def stream_chatbot_start(
    request: Request,
    instanceId: str = Path(..., description="Feature Instance ID"),
    workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue (can also be in request body)"),
    userInput: UserInputRequest = Body(...),
    context: RequestContext = Depends(getRequestContext)
) -> StreamingResponse:
    """
    Starts a new chatbot workflow or continues an existing one with SSE streaming.
    Streams progress updates in real-time via Server-Sent Events.

    workflowId can be provided either:
    - As a query parameter: /api/chatbot/{instanceId}/start/stream?workflowId=xxx
    - In the request body as part of UserInputRequest
    - Query parameter takes precedence if both are provided

    Returns:
        StreamingResponse with media type text/event-stream. Each frame is
        ``data: <json>`` in chatData item format ({type, createdAt, item});
        ``: keepalive`` comment frames are sent while the queue is idle.

    Raises:
        HTTPException 404/400/403 from instance validation
        HTTPException 500 if no workflow could be created or loaded, or on any
        unexpected error before the stream starts
    """
    # Validate instance access
    mandateId = await _validateInstanceAccess(instanceId, context)

    event_manager = get_event_manager()

    try:
        # Use workflowId from query parameter if provided, otherwise from request body
        final_workflow_id = workflowId or userInput.workflowId

        # Start processing; featureInstanceId is passed through to chatProcess
        workflow = await chatProcess(context.user, mandateId, userInput, final_workflow_id, featureInstanceId=instanceId)

        # Check if workflow was created successfully
        if not workflow:
            raise HTTPException(
                status_code=500,
                detail="Failed to create or load workflow"
            )

        # Get event queue for the workflow, creating it if it doesn't exist yet
        queue = event_manager.get_queue(workflow.id)
        if not queue:
            queue = event_manager.create_queue(workflow.id)

        async def event_stream():
            """
            Async generator for SSE frames.

            Emits the stored chat items once, then forwards queued events until a
            terminal event, a "stopped" workflow status, a client disconnect, or
            the overall timeout. Workflow status is re-read every few seconds; the
            chat data itself is only read once at the start.
            """
            try:
                # Named differently from the module-level interfaceDbChat import
                # so that import is not shadowed inside this generator
                chatInterface = _getServiceChat(context, instanceId)

                # Get current workflow to check if resuming and get current round
                current_workflow = chatInterface.getWorkflow(workflow.id)
                current_round = current_workflow.currentRound if current_workflow else None
                is_resuming = final_workflow_id is not None and current_round and current_round > 1

                # Send initial chat data (exact format as chatData endpoint) - only once at start
                try:
                    chatData = chatInterface.getUnifiedChatData(workflow.id, None)
                    if chatData.get("items"):
                        # Filter items by round number if resuming
                        filtered_items = []
                        for item in chatData["items"]:
                            if is_resuming and current_round:
                                # Get round number from item (dict or object access)
                                item_round = None
                                item_data = item.get("item")
                                if item_data:
                                    if isinstance(item_data, dict):
                                        item_round = item_data.get("roundNumber")
                                    elif hasattr(item_data, "roundNumber"):
                                        item_round = item_data.roundNumber

                                # When resuming, only include items from the current round onwards.
                                # Items without a roundNumber are treated as belonging to old rounds.
                                if item_round is None or item_round < current_round:
                                    continue

                            filtered_items.append(item)

                        # Emit filtered items
                        for item in filtered_items:
                            # Convert Pydantic models to dicts for JSON serialization
                            serializable_item = {
                                "type": item.get("type"),
                                "createdAt": item.get("createdAt"),
                                "item": item.get("item").dict() if hasattr(item.get("item"), "dict") else item.get("item")
                            }
                            # Emit item directly in exact chatData format: {type, createdAt, item}
                            yield f"data: {json.dumps(serializable_item)}\n\n"
                except Exception as e:
                    # Best effort: a failure here must not prevent live events from streaming
                    logger.warning(f"Error fetching initial chat data: {e}")

                # We are inside a running coroutine, so the running loop is the right clock source
                loop = asyncio.get_running_loop()

                # Keepalive interval (30 seconds)
                keepalive_interval = 30.0
                last_keepalive = loop.time()

                # Workflow status is polled every 5 seconds as a safety net for missed events
                status_check_interval = 5.0
                last_status_check = loop.time()

                # Stream events until completion or timeout
                timeout = 300.0  # 5 minutes max
                start_time = loop.time()

                while True:
                    # Check timeout
                    elapsed = loop.time() - start_time
                    if elapsed > timeout:
                        logger.info(f"Stream timeout for workflow {workflow.id}")
                        break

                    # Check for client disconnection
                    if await request.is_disconnected():
                        logger.info(f"Client disconnected for workflow {workflow.id}")
                        break

                    current_time = loop.time()

                    # Periodically check workflow status
                    if current_time - last_status_check >= status_check_interval:
                        try:
                            current_workflow = chatInterface.getWorkflow(workflow.id)
                            if current_workflow and current_workflow.status == "stopped":
                                logger.info(f"Workflow {workflow.id} was stopped, closing stream")
                                break
                        except Exception as e:
                            logger.warning(f"Error checking workflow status: {e}")
                        last_status_check = current_time

                    # Wait up to one second for the next queued event so the checks above keep running
                    try:
                        event = await asyncio.wait_for(queue.get(), timeout=1.0)

                        # Handle different event types
                        event_type = event.get("type")
                        event_data = event.get("data", {})

                        # Emit chatdata events (messages, logs, stats) in exact chatData format
                        if event_type == "chatdata" and event_data:
                            chatdata_item = event_data
                            # Ensure item field is serializable (convert Pydantic models to dicts)
                            if isinstance(chatdata_item, dict) and "item" in chatdata_item:
                                item_obj = chatdata_item.get("item")
                                if hasattr(item_obj, "dict"):
                                    # Copy first so the queued event itself is not mutated
                                    chatdata_item = chatdata_item.copy()
                                    chatdata_item["item"] = item_obj.dict()
                            yield f"data: {json.dumps(chatdata_item)}\n\n"

                        # Handle completion/stopped/error events to close stream
                        elif event_type == "complete":
                            logger.info(f"Workflow {workflow.id} completed, closing stream")
                            break
                        elif event_type == "stopped":
                            logger.info(f"Workflow {workflow.id} stopped, closing stream")
                            break
                        elif event_type == "error" and event.get("step") == "error":
                            logger.warning(f"Workflow {workflow.id} error, closing stream")
                            break

                        # Any processed event counts as traffic, so the keepalive timer restarts
                        last_keepalive = current_time
                    except asyncio.TimeoutError:
                        # Send keepalive if needed (no events received, but keep connection alive)
                        current_time = loop.time()
                        if current_time - last_keepalive >= keepalive_interval:
                            yield ": keepalive\n\n"
                            last_keepalive = current_time
                        continue
                    except Exception as e:
                        logger.error(f"Error in event stream: {e}")
                        break

            except Exception as e:
                # The response has already started, so the error can only be logged;
                # the stream simply ends. Queue cleanup is left to the event manager.
                logger.error(f"Error in event stream generator: {e}", exc_info=True)

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"  # Disable buffering for nginx
            }
        )

    except HTTPException:
        # Keep the intended status code and detail, as the sibling endpoints do
        raise
    except Exception as e:
        logger.error(f"Error in stream_chatbot_start: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )
|
|
|
|
|
|
# Workflow stop endpoint
|
|
@router.post("/{instanceId}/stop/{workflowId}", response_model=ChatWorkflow)
@limiter.limit("120/minute")
async def stop_chatbot(
    request: Request,
    instanceId: str = Path(..., description="Feature Instance ID"),
    workflowId: str = Path(..., description="ID of the workflow to stop"),
    context: RequestContext = Depends(getRequestContext)
) -> ChatWorkflow:
    """
    Stops a running chatbot workflow.

    Marks the workflow as "stopped", stores a warning log entry, emits a
    "stopped" event for the workflow and returns the reloaded workflow.

    Raises:
        HTTPException 404 if the workflow does not exist
        HTTPException 403 if the workflow is bound to a different instance
        HTTPException 500 on any unexpected error
    """
    # Raises on missing instance or missing access; the returned mandateId is not needed here
    await _validateInstanceAccess(instanceId, context)

    try:
        chatInterface = _getServiceChat(context, instanceId)

        workflow = chatInterface.getWorkflow(workflowId)
        if not workflow:
            raise HTTPException(status_code=404, detail=f"Workflow {workflowId} not found")

        # A workflow without a featureInstanceId is accepted; only a mismatch is rejected
        if workflow.featureInstanceId and workflow.featureInstanceId != instanceId:
            raise HTTPException(
                status_code=403,
                detail=f"Workflow {workflowId} does not belong to instance {instanceId}"
            )

        # Persist the new status first, then record why it changed
        statusUpdate = {"status": "stopped", "lastActivity": getUtcTimestamp()}
        chatInterface.updateWorkflow(workflowId, statusUpdate)

        stopLogEntry = {
            "id": f"log_{uuid.uuid4()}",
            "workflowId": workflowId,
            "message": "Workflow stopped by user",
            "type": "warning",
            "status": "stopped",
            "timestamp": getUtcTimestamp(),
            # workflow passed the not-found guard above, so no fallback round is needed
            "roundNumber": workflow.currentRound
        }
        chatInterface.createLog(stopLogEntry)

        # Reload so the response reflects the update made above
        workflow = chatInterface.getWorkflow(workflowId)

        # Emit a "stopped" event for this workflow via the event manager
        await get_event_manager().emit_event(
            context_id=workflowId,
            event_type="stopped",
            data={"workflowId": workflowId},
            event_category="workflow",
            message="Workflow stopped by user",
            step="stopped"
        )
        logger.info(f"Stopped workflow {workflowId} and emitted stopped event")

        return workflow

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in stop_chatbot: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
# Delete chatbot workflow endpoint
|
|
@router.delete("/{instanceId}/{workflowId}", response_model=Dict[str, Any])
@limiter.limit("120/minute")
async def delete_chatbot(
    request: Request,
    instanceId: str = Path(..., description="Feature Instance ID"),
    workflowId: str = Path(..., description="ID of the workflow to delete"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """
    Deletes a chatbot workflow and its associated data.

    Returns:
        Dict with the deleted workflow "id" and a confirmation "message"

    Raises:
        HTTPException 404 if the RBAC-filtered lookup finds no such workflow
        HTTPException 400 if the workflow is not a chatbot workflow
        HTTPException 403 if it belongs to another instance or delete permission is missing
        HTTPException 500 if deletion fails or an unexpected error occurs
    """
    # Raises on missing instance or missing access; the returned mandateId is not needed here
    await _validateInstanceAccess(instanceId, context)

    try:
        chatInterface = _getServiceChat(context, instanceId)

        # RBAC-filtered lookup: a workflow the user may not see comes back as "not found"
        matches = getRecordsetWithRBAC(
            chatInterface.db,
            ChatWorkflow,
            context.user,
            recordFilter={"id": workflowId}
        )
        if not matches:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Workflow with ID {workflowId} not found"
            )
        record = matches[0]

        # Only chatbot-mode workflows may be deleted through this endpoint
        if record.get("workflowMode") != WorkflowModeEnum.WORKFLOW_CHATBOT.value:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Workflow {workflowId} is not a chatbot workflow"
            )

        # Strict match: a workflow without a featureInstanceId is rejected as well
        if record.get("featureInstanceId") != instanceId:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Workflow {workflowId} does not belong to instance '{instanceId}'"
            )

        if not chatInterface.checkRbacPermission(ChatWorkflow, "delete", workflowId):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You don't have permission to delete this workflow"
            )

        if not chatInterface.deleteWorkflow(workflowId):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to delete workflow"
            )

        return {
            "id": workflowId,
            "message": "Chatbot workflow and associated data deleted successfully"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in delete_chatbot: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error deleting chatbot workflow: {str(e)}")
|
|
|
|
# List chatbot threads/workflows or get specific thread details
|
|
@router.get("/{instanceId}/threads")
@limiter.limit("120/minute")
async def get_chatbot_threads(
    request: Request,
    instanceId: str = Path(..., description="Feature Instance ID"),
    workflowId: Optional[str] = Query(None, description="Optional workflow ID to get details and chat data for a specific thread"),
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object (only used when workflowId is not provided)"),
    context: RequestContext = Depends(getRequestContext)
) -> Union[PaginatedResponse[ChatWorkflow], Dict[str, Any]]:
    """
    List all chatbot workflows (threads) for the current user, or get details and chat data for a specific thread.

    - If workflowId is provided: Returns the workflow details and all chat data (messages, logs, stats)
    - If workflowId is not provided: Returns a paginated list of all workflows

    Raises:
        HTTPException 404 if workflowId is given but no such workflow is found
        HTTPException 403 if the requested workflow is bound to a different instance
        HTTPException 400 if the pagination parameter is not valid JSON, is not a
        JSON object, or is rejected by PaginationParams
        HTTPException 500 on any unexpected error
    """
    # Raises on missing instance or missing access; the returned mandateId is not needed here
    await _validateInstanceAccess(instanceId, context)

    try:
        chatInterface = _getServiceChat(context, instanceId)

        # If workflowId is provided, return single workflow with chat data
        if workflowId:
            workflow = chatInterface.getWorkflow(workflowId)
            if not workflow:
                raise HTTPException(
                    status_code=404,
                    detail=f"Workflow with ID {workflowId} not found"
                )

            # Convert workflow object to dict if needed
            if hasattr(workflow, 'model_dump'):
                workflow_dict = workflow.model_dump()
            elif hasattr(workflow, 'dict'):
                workflow_dict = workflow.dict()
            elif isinstance(workflow, dict):
                workflow_dict = dict(workflow)
            else:
                workflow_dict = workflow

            # Verify workflow belongs to this instance, with the same rule as the
            # stop endpoint: a workflow without an instance ID is accepted
            ownerInstanceId = workflow_dict.get("featureInstanceId")
            if ownerInstanceId and str(ownerInstanceId) != instanceId:
                raise HTTPException(
                    status_code=403,
                    detail=f"Workflow {workflowId} does not belong to instance {instanceId}"
                )

            # Set maxSteps to default value of 10 if None (as per ChatWorkflow model)
            if workflow_dict.get("maxSteps") is None:
                workflow_dict["maxSteps"] = 10

            # Get unified chat data for this workflow
            chatData = chatInterface.getUnifiedChatData(workflowId, None)

            return {
                "workflow": workflow_dict,
                "chatData": chatData
            }

        # Otherwise, return paginated list of workflows
        # Parse pagination parameter
        paginationParams = None
        if pagination:
            try:
                paginationDict = json.loads(pagination)
                # Falsy payloads ({} / null / 0 / []) mean "no pagination", as before
                if paginationDict:
                    # Valid JSON that is not an object cannot be expanded into keyword
                    # arguments; report it as a client error instead of a 500
                    if not isinstance(paginationDict, dict):
                        raise ValueError("pagination must be a JSON object")
                    paginationParams = PaginationParams(**paginationDict)
            except (json.JSONDecodeError, ValueError) as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid pagination parameter: {str(e)}"
                )

        # All workflows are fetched unpaginated so the workflowMode filter
        # can be applied before pagination
        all_workflows = chatInterface.getWorkflows(pagination=None)

        # Filter to only include chatbot workflows
        chatbot_workflows_data = [
            wf for wf in all_workflows
            if wf.get("workflowMode") == WorkflowModeEnum.WORKFLOW_CHATBOT.value
        ]

        # Apply pagination if requested
        if paginationParams:
            # Apply sorting if provided
            if paginationParams.sort:
                chatbot_workflows_data = chatInterface._applySorting(chatbot_workflows_data, paginationParams.sort)

            # Count total items after filtering
            totalItems = len(chatbot_workflows_data)
            totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0

            # Apply pagination (skip/limit)
            # NOTE(review): assumes PaginationParams enforces page >= 1 and pageSize >= 1 - confirm
            startIdx = (paginationParams.page - 1) * paginationParams.pageSize
            endIdx = startIdx + paginationParams.pageSize
            workflows = chatbot_workflows_data[startIdx:endIdx]
        else:
            # Without pagination everything is returned as a single page
            workflows = chatbot_workflows_data
            totalItems = len(chatbot_workflows_data)
            totalPages = 1

        # Normalize workflow data to match ChatWorkflow model requirements
        normalized_workflows = []
        for wf in workflows:
            normalized_wf = dict(wf)  # Create a copy
            # Set maxSteps to default value of 10 if None (as per ChatWorkflow model)
            if normalized_wf.get("maxSteps") is None:
                normalized_wf["maxSteps"] = 10
            normalized_workflows.append(normalized_wf)

        # Create paginated response
        metadata = PaginationMetadata(
            currentPage=paginationParams.page if paginationParams else 1,
            pageSize=paginationParams.pageSize if paginationParams else len(workflows),
            totalItems=totalItems,
            totalPages=totalPages,
            sort=paginationParams.sort if paginationParams else [],
            filters=paginationParams.filters if paginationParams else None
        )

        return PaginatedResponse(
            items=normalized_workflows,
            pagination=metadata
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting chatbot threads: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error getting chatbot threads: {str(e)}"
        )
|