# wiki/implementation/Chatbot/legacy/router.py
#
# 119 lines
# 3.3 KiB
# Python

"""Router for chat related endpoints."""
import logging
import time
from fastapi import HTTPException
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from langchain_anthropic import ChatAnthropic
from langchain_core.embeddings import Embeddings as LCEmbeddings
from langgraph.checkpoint.sqlite import SqliteSaver
from src.auth.dependencies import authenticate
from src.chat.schemas import (
PostChatMessageRequest,
PostChatMessageResponse,
)
from src.common.errors import ErrorResponse
from src.dependencies import (
get_embeddings,
get_chatmodel,
get_chatmemory,
)
from src.chat import service as chat_service
# Router instance that the chat endpoints defined in this module register on.
router = APIRouter()
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
@router.post(
    "/message",
    response_model=PostChatMessageResponse,
    responses={
        200: {"model": PostChatMessageResponse},
        400: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def post_message(
    request: PostChatMessageRequest,
    embeddings: LCEmbeddings = Depends(get_embeddings),
    chatmodel: ChatAnthropic = Depends(get_chatmodel),
    chatmemory: SqliteSaver = Depends(get_chatmemory),
    username: str = Depends(authenticate),
) -> PostChatMessageResponse:
    """Endpoint to send a chat message.

    The request is forwarded unchanged to ``chat_service.post_message`` and
    the service's result is returned as the response body.

    Args:
        request: The chat message request (provides ``message`` and ``thread``).
        embeddings: The embeddings model, injected via ``get_embeddings``.
        chatmodel: The chat model, injected via ``get_chatmodel``.
        chatmemory: The chat memory, injected via ``get_chatmemory``.
        username: Value resolved by the ``authenticate`` dependency. It is not
            used in the body; declaring it makes the dependency run for every
            request (presumably rejecting unauthenticated callers -- confirm
            in ``src.auth.dependencies``).

    Returns:
        The response containing the full chat message history and thread ID.
    """
    # Lazy %-style arguments: the message is only formatted when INFO is enabled.
    logger.info(
        "Received message: %s for thread %s", request.message, request.thread
    )
    # TODO: Ratelimits / Credits tbd.
    return await chat_service.post_message(
        thread_id=request.thread,
        message=request.message,
        chatmodel=chatmodel,
        chatmemory=chatmemory,
        embeddings=embeddings,
    )
@router.post("/message/stream")
async def post_message_stream(
    request: PostChatMessageRequest,
    embeddings: LCEmbeddings = Depends(get_embeddings),
    chatmodel: ChatAnthropic = Depends(get_chatmodel),
    chatmemory: SqliteSaver = Depends(get_chatmemory),
    username: str = Depends(authenticate),
) -> StreamingResponse:
    """Endpoint to send a chat message with streaming progress updates.

    The iterable returned by ``chat_service.post_message_stream`` is wrapped
    in a ``StreamingResponse`` served as ``text/event-stream``.

    Args:
        request: The chat message request (provides ``message`` and ``thread``).
        embeddings: The embeddings model, injected via ``get_embeddings``.
        chatmodel: The chat model, injected via ``get_chatmodel``.
        chatmemory: The chat memory, injected via ``get_chatmemory``.
        username: Value resolved by the ``authenticate`` dependency. It is not
            used in the body; declaring it makes the dependency run for every
            request (presumably rejecting unauthenticated callers -- confirm
            in ``src.auth.dependencies``).

    Returns:
        StreamingResponse with Server-Sent Events for progress updates.
    """
    # Lazy %-style arguments: the message is only formatted when INFO is enabled.
    logger.info(
        "Received streaming message: %s for thread %s",
        request.message,
        request.thread,
    )
    event_stream = chat_service.post_message_stream(
        thread_id=request.thread,
        message=request.message,
        chatmodel=chatmodel,
        chatmemory=chatmemory,
        embeddings=embeddings,
    )
    return StreamingResponse(
        event_stream,
        media_type="text/event-stream",
        headers={
            # Ask clients/intermediaries not to cache the event stream and to
            # keep the connection open while events are sent.
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )