350 lines
15 KiB
Python
350 lines
15 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Chatbot V2 routes - context-aware chat with file upload and extraction.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import math
|
|
import logging
|
|
import uuid
|
|
from typing import Optional, Any, Dict, Union
|
|
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, status, UploadFile, File
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
from modules.auth import limiter, getRequestContext, RequestContext
|
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
|
from modules.interfaces.interfaceFeatures import getFeatureInterface
|
|
from modules.datamodels.datamodelChat import UserInputRequest
|
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
|
from modules.shared.timeUtils import getUtcTimestamp
|
|
|
|
from . import interfaceFeatureChatbotV2 as interfaceDbChat
|
|
from .interfaceFeatureChatbotV2 import getInterface as getChatbotV2Interface
|
|
from .datamodelFeatureChatbotV2 import ChatbotV2Conversation
|
|
from .serviceChatbotV2 import uploadAndExtract, chatProcessV2
|
|
from modules.features.chatbot.streaming.events import get_event_manager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(
|
|
prefix="/api/chatbotv2",
|
|
tags=["Chatbot V2"],
|
|
responses={404: {"description": "Not found"}}
|
|
)
|
|
|
|
|
|
class UploadRequest(BaseModel):
    """Request body for file upload - files must be uploaded to central storage first."""
    # IDs of files already present in central storage. The upload endpoint
    # rejects an empty list with HTTP 400, so at least one ID is required in practice.
    listFileId: list[str] = Field(default_factory=list, description="List of file IDs from central storage")
|
|
|
|
|
|
def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None):
    """Build a ChatbotV2 DB interface bound to the caller's user, mandate, and instance."""
    # The interface expects the mandate id as a string (or None when absent).
    if context.mandateId:
        resolvedMandateId = str(context.mandateId)
    else:
        resolvedMandateId = None
    return getChatbotV2Interface(context.user, mandateId=resolvedMandateId, featureInstanceId=instanceId)
|
|
|
|
|
|
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
    """Check that the user may access the chatbotv2 feature instance.

    Returns the instance's mandate id as a string.

    Raises:
        HTTPException: 404 when the instance does not exist, 400 when it is not
            a chatbotv2 instance, 403 when the user has no enabled access entry
            (sysadmins bypass the access-entry check).
    """
    root = getRootInterface()
    features = getFeatureInterface(root.db)
    instance = features.getFeatureInstance(instanceId)

    # Guard clauses: unknown instance, then wrong feature type.
    if not instance:
        raise HTTPException(
            status_code=404,
            detail=f"Feature instance '{instanceId}' not found"
        )
    if instance.featureCode != "chatbotv2":
        raise HTTPException(
            status_code=400,
            detail=f"Instance '{instanceId}' is not a chatbotv2 instance"
        )

    # Non-admins must hold an enabled access record for this exact instance.
    if not context.hasSysAdminRole:
        granted = False
        for access in root.getFeatureAccessesForUser(str(context.user.id)):
            if str(access.featureInstanceId) == instanceId and access.enabled:
                granted = True
                break
        if not granted:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied to feature instance '{instanceId}'"
            )

    return str(instance.mandateId)
|
|
|
|
|
|
# =============================================================================
|
|
# Upload - start extraction
|
|
# =============================================================================
|
|
# =============================================================================
# Upload - start extraction
# =============================================================================
@router.post("/{instanceId}/upload")
@limiter.limit("60/minute")
async def upload_files(
    request: Request,
    instanceId: str = Path(..., description="Feature Instance ID"),
    body: UploadRequest = Body(...),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """
    Upload files as context and start extraction.
    Files must be uploaded to central storage first; pass their file IDs in listFileId.
    Returns conversationId. Extraction runs in background; poll threads or use SSE for status.
    """
    mandateId = _validateInstanceAccess(instanceId, context)

    # An empty file list is a client error, not something to hand to the service.
    if not body.listFileId:
        raise HTTPException(status_code=400, detail="listFileId is required and must not be empty")

    try:
        conversation = await uploadAndExtract(
            context.user,
            mandateId=mandateId,
            instanceId=instanceId,
            listFileId=body.listFileId
        )
        pollHint = "Extraction started. Poll GET /threads?workflowId={} for status.".format(conversation.id)
        return {
            "conversationId": conversation.id,
            "status": conversation.status,
            "message": pollHint
        }
    except Exception as exc:
        # Surface unexpected service failures as a 500 with the error text.
        logger.error(f"Error in upload_files: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
# =============================================================================
|
|
# List threads - MUST be first to avoid /{instanceId}/{workflowId} matching
|
|
# =============================================================================
|
|
# =============================================================================
# List threads - MUST be first to avoid /{instanceId}/{workflowId} matching
# =============================================================================
@router.get("/{instanceId}/threads")
@limiter.limit("120/minute")
def get_threads(
    request: Request,
    instanceId: str = Path(..., description="Feature Instance ID"),
    workflowId: Optional[str] = Query(None, description="Optional workflow/conversation ID for details"),
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
    context: RequestContext = Depends(getRequestContext)
) -> Union[PaginatedResponse, Dict[str, Any]]:
    """List conversations or get details for a specific one."""
    _validateInstanceAccess(instanceId, context)
    # Local name avoids shadowing the module-level `interfaceDbChat` import.
    chatDb = _getServiceChat(context, instanceId)

    # Detail mode: a workflowId returns that conversation plus its unified chat data.
    if workflowId:
        conversation = chatDb.getConversation(workflowId)
        if not conversation:
            raise HTTPException(status_code=404, detail=f"Conversation {workflowId} not found")
        return {
            "workflow": conversation.model_dump(),
            "chatData": chatDb.getUnifiedChatData(workflowId, None)
        }

    # Decode the optional JSON-encoded pagination parameters.
    pageParams = None
    if pagination:
        try:
            decoded = json.loads(pagination)
            if decoded:
                pageParams = PaginationParams(**decoded)
        except (json.JSONDecodeError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")

    # getConversations can yield dicts (from getRecordsetWithRBAC) or models;
    # normalize everything to plain dicts before slicing.
    normalized = []
    for record in chatDb.getConversations(pagination=None):
        normalized.append(record if isinstance(record, dict) else record.model_dump())

    totalItems = len(normalized)
    if pageParams:
        # In-memory pagination over the full result set.
        totalPages = math.ceil(totalItems / pageParams.pageSize) if totalItems > 0 else 0
        start = (pageParams.page - 1) * pageParams.pageSize
        workflows = normalized[start:start + pageParams.pageSize]
    else:
        totalPages = 1
        workflows = normalized

    metadata = PaginationMetadata(
        currentPage=pageParams.page if pageParams else 1,
        pageSize=pageParams.pageSize if pageParams else len(workflows),
        totalItems=totalItems,
        totalPages=totalPages,
        sort=pageParams.sort if pageParams else [],
        filters=pageParams.filters if pageParams else None
    )
    return PaginatedResponse(items=workflows, pagination=metadata)
|
|
|
|
|
|
# =============================================================================
|
|
# Start/continue chat (SSE stream)
|
|
# =============================================================================
|
|
# =============================================================================
# Start/continue chat (SSE stream)
# =============================================================================
@router.post("/{instanceId}/start/stream")
@limiter.limit("120/minute")
async def stream_chat_start(
    request: Request,
    instanceId: str = Path(..., description="Feature Instance ID"),
    workflowId: Optional[str] = Query(None, description="Optional conversation ID to continue"),
    userInput: UserInputRequest = Body(...),
    context: RequestContext = Depends(getRequestContext)
) -> StreamingResponse:
    """Start or continue a chat with SSE streaming.

    Replays the conversation's existing chat items first, then forwards live
    events from the workflow's event queue until the workflow completes, is
    stopped, errors, the client disconnects, or a 5-minute timeout elapses.
    """
    mandateId = _validateInstanceAccess(instanceId, context)
    event_manager = get_event_manager()
    # Conversation id can arrive via query string or request body; query wins.
    final_workflow_id = workflowId or userInput.workflowId

    try:
        # Kick off (or resume) the chat workflow before opening the stream.
        # NOTE(review): presumably chatProcessV2 launches processing in the
        # background and returns the conversation record — confirm in serviceChatbotV2.
        workflow = await chatProcessV2(
            context.user,
            mandateId=mandateId,
            userInput=userInput,
            conversationId=final_workflow_id,
            instanceId=instanceId
        )
        if not workflow:
            raise HTTPException(status_code=500, detail="Failed to create or load workflow")

        # Reuse the workflow's event queue if one already exists; otherwise create it.
        queue = event_manager.get_queue(workflow.id)
        if not queue:
            queue = event_manager.create_queue(workflow.id)

        async def event_stream():
            """Generator yielding SSE `data:` frames (and keepalive comments)."""
            try:
                interfaceDbChat = _getServiceChat(context, instanceId)
                # Replay persisted chat history first so (re)connecting clients catch up.
                chatData = interfaceDbChat.getUnifiedChatData(workflow.id, None)
                if chatData.get("items"):
                    for item in chatData["items"]:
                        ser = {
                            "type": item.get("type"),
                            "createdAt": item.get("createdAt"),
                            # Pydantic models are serialized; plain dicts pass through unchanged.
                            "item": item.get("item").model_dump() if hasattr(item.get("item"), "model_dump") else item.get("item")
                        }
                        yield f"data: {json.dumps(ser)}\n\n"

                # Loop clocks: SSE keepalive comment every 30s of silence, DB status
                # poll every 5s, hard stop after 300s total.
                keepalive_interval = 30.0
                last_keepalive = asyncio.get_event_loop().time()
                status_check_interval = 5.0
                last_status_check = asyncio.get_event_loop().time()
                timeout = 300.0
                start_time = asyncio.get_event_loop().time()

                while True:
                    elapsed = asyncio.get_event_loop().time() - start_time
                    if elapsed > timeout:
                        break
                    if await request.is_disconnected():
                        break
                    current_time = asyncio.get_event_loop().time()

                    # Periodically re-read the conversation so a stop issued
                    # out-of-band (e.g. via the /stop endpoint) ends this stream.
                    if current_time - last_status_check >= status_check_interval:
                        try:
                            cw = interfaceDbChat.getConversation(workflow.id)
                            if cw and cw.status == "stopped":
                                break
                        except Exception:
                            # Best-effort poll; a transient DB error must not kill the stream.
                            pass
                        last_status_check = current_time

                    try:
                        # 1s wait bound keeps the disconnect/timeout checks responsive.
                        event = await asyncio.wait_for(queue.get(), timeout=1.0)
                        event_type = event.get("type")
                        event_data = event.get("data", {})

                        if event_type == "chatdata" and event_data:
                            if event_data.get("type") == "status":
                                # Lightweight status label frame for progress UI.
                                yield f"data: {json.dumps({'type': 'status', 'label': event_data.get('label', '')})}\n\n"
                            else:
                                item = event_data
                                # Serialize a nested pydantic model before dumping to JSON.
                                if isinstance(item, dict) and "item" in item:
                                    obj = item.get("item")
                                    if hasattr(obj, "model_dump"):
                                        item = {**item, "item": obj.model_dump()}
                                yield f"data: {json.dumps(item)}\n\n"
                        elif event_type in ("complete", "stopped"):
                            break
                        elif event_type == "error" and event.get("step") == "error":
                            break
                        # A delivered event counts as traffic: reset the keepalive clock.
                        last_keepalive = current_time
                    except asyncio.TimeoutError:
                        # No event within 1s; emit an SSE comment line if keepalive is due.
                        if current_time - last_keepalive >= keepalive_interval:
                            yield ": keepalive\n\n"
                            last_keepalive = current_time
                    except Exception as e:
                        logger.error(f"Error in event stream: {e}")
                        break
            except Exception as e:
                # Response headers are already sent; all we can do is log.
                logger.error(f"Error in event stream generator: {e}", exc_info=True)

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Disable proxy (nginx) buffering so frames flush immediately.
                "X-Accel-Buffering": "no"
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in stream_chat_start: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# =============================================================================
|
|
# Stop chat
|
|
# =============================================================================
|
|
# =============================================================================
# Stop chat
# =============================================================================
@router.post("/{instanceId}/stop/{workflowId}")
@limiter.limit("120/minute")
async def stop_chat(
    request: Request,
    instanceId: str = Path(..., description="Feature Instance ID"),
    workflowId: str = Path(..., description="Conversation ID to stop"),
    context: RequestContext = Depends(getRequestContext)
) -> ChatbotV2Conversation:
    """Mark a chat as stopped, write an audit log entry, and notify stream listeners."""
    _validateInstanceAccess(instanceId, context)
    chatDb = _getServiceChat(context, instanceId)

    if not chatDb.getConversation(workflowId):
        raise HTTPException(status_code=404, detail=f"Conversation {workflowId} not found")

    # Persist the stopped state; the SSE loop also polls this status.
    chatDb.updateConversation(workflowId, {"status": "stopped", "lastActivity": getUtcTimestamp()})
    chatDb.createLog({
        "conversationId": workflowId,
        "message": "Workflow stopped by user",
        "type": "warning",
        "status": "stopped",
        "timestamp": getUtcTimestamp()
    })

    # Push a "stopped" event so any open streams terminate promptly.
    await get_event_manager().emit_event(
        context_id=workflowId,
        event_type="stopped",
        data={"workflowId": workflowId},
        event_category="workflow",
        message="Workflow stopped by user",
        step="stopped"
    )

    # Re-read so the caller sees the updated status.
    return chatDb.getConversation(workflowId)
|
|
|
|
|
|
# =============================================================================
|
|
# Delete conversation - use /conversations/{workflowId} to avoid
|
|
# /{instanceId}/{workflowId} matching GET /threads (workflowId="threads")
|
|
# =============================================================================
|
|
# =============================================================================
# Delete conversation - use /conversations/{workflowId} to avoid
# /{instanceId}/{workflowId} matching GET /threads (workflowId="threads")
# =============================================================================
@router.delete("/{instanceId}/conversations/{workflowId}")
@limiter.limit("120/minute")
def delete_conversation(
    request: Request,
    instanceId: str = Path(..., description="Feature Instance ID"),
    workflowId: str = Path(..., description="Conversation ID to delete"),
    context: RequestContext = Depends(getRequestContext)
) -> Dict[str, Any]:
    """Delete a conversation and its data."""
    _validateInstanceAccess(instanceId, context)
    chatDb = _getServiceChat(context, instanceId)

    # 404 when the conversation is unknown (or not visible to this user).
    if not chatDb.getConversation(workflowId):
        raise HTTPException(status_code=404, detail=f"Conversation {workflowId} not found")

    if not chatDb.deleteConversation(workflowId):
        raise HTTPException(status_code=500, detail="Failed to delete conversation")

    return {"id": workflowId, "message": "Conversation deleted successfully"}
|