119 lines
3.3 KiB
Python
119 lines
3.3 KiB
Python
"""Router for chat related endpoints."""
|
|
|
|
import logging
|
|
import time
|
|
|
|
from fastapi import HTTPException
|
|
from fastapi import APIRouter, Depends
|
|
from fastapi.responses import StreamingResponse
|
|
from langchain_anthropic import ChatAnthropic
|
|
from langchain_core.embeddings import Embeddings as LCEmbeddings
|
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
|
|
|
from src.auth.dependencies import authenticate
|
|
from src.chat.schemas import (
|
|
PostChatMessageRequest,
|
|
PostChatMessageResponse,
|
|
)
|
|
from src.common.errors import ErrorResponse
|
|
from src.dependencies import (
|
|
get_embeddings,
|
|
get_chatmodel,
|
|
get_chatmemory,
|
|
)
|
|
|
|
|
|
from src.chat import service as chat_service
|
|
|
|
# Module-scoped logger; its name follows this module's import path.
logger = logging.getLogger(__name__)

# APIRouter instance on which the chat endpoints below are registered.
router = APIRouter()
|
|
|
|
|
|
@router.post(
    "/message",
    response_model=PostChatMessageResponse,
    responses={
        200: {"model": PostChatMessageResponse},
        400: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def post_message(
    request: PostChatMessageRequest,
    embeddings: LCEmbeddings = Depends(get_embeddings),
    chatmodel: ChatAnthropic = Depends(get_chatmodel),
    chatmemory: SqliteSaver = Depends(get_chatmemory),
    username: str = Depends(authenticate),
) -> PostChatMessageResponse:
    """Send a chat message and return the updated conversation.

    Delegates the actual work to ``chat_service.post_message``.

    Args:
        request: The chat message request; its ``message`` and ``thread``
            fields are forwarded to the chat service.
        embeddings: Embeddings model injected via ``get_embeddings``.
        chatmodel: Anthropic chat model injected via ``get_chatmodel``.
        chatmemory: SQLite checkpointer injected via ``get_chatmemory``.
        username: Identity returned by the ``authenticate`` dependency. It is
            only used for logging here; declaring the dependency makes it run
            for every request (presumably rejecting unauthenticated callers —
            confirm in ``src.auth.dependencies``).

    Returns:
        The response containing the full chat message history and thread ID.
    """
    # Keep raw user content out of INFO-level logs; it is only emitted at DEBUG.
    # Lazy %-style arguments avoid formatting when the level is disabled.
    logger.info(
        "Received message for thread %s from user %s", request.thread, username
    )
    logger.debug("Message content for thread %s: %s", request.thread, request.message)

    # TODO: Ratelimits / Credits tbd.

    response = await chat_service.post_message(
        thread_id=request.thread,
        message=request.message,
        chatmodel=chatmodel,
        chatmemory=chatmemory,
        embeddings=embeddings,
    )

    return response
|
|
|
|
|
|
@router.post("/message/stream")
async def post_message_stream(
    request: PostChatMessageRequest,
    embeddings: LCEmbeddings = Depends(get_embeddings),
    chatmodel: ChatAnthropic = Depends(get_chatmodel),
    chatmemory: SqliteSaver = Depends(get_chatmemory),
    username: str = Depends(authenticate),
) -> StreamingResponse:
    """Send a chat message and stream progress updates back to the client.

    The response body is produced by ``chat_service.post_message_stream`` and
    served with the ``text/event-stream`` media type.

    Args:
        request: The chat message request; its ``message`` and ``thread``
            fields are forwarded to the chat service.
        embeddings: Embeddings model injected via ``get_embeddings``.
        chatmodel: Anthropic chat model injected via ``get_chatmodel``.
        chatmemory: SQLite checkpointer injected via ``get_chatmemory``.
        username: Identity returned by the ``authenticate`` dependency. It is
            only used for logging here; declaring the dependency makes it run
            for every request (presumably rejecting unauthenticated callers —
            confirm in ``src.auth.dependencies``).

    Returns:
        StreamingResponse with Server-Sent Events for progress updates.
    """
    # Keep raw user content out of INFO-level logs; it is only emitted at DEBUG.
    logger.info(
        "Received streaming message for thread %s from user %s",
        request.thread,
        username,
    )
    logger.debug(
        "Streaming message content for thread %s: %s", request.thread, request.message
    )

    # NOTE(review): the stream body is consumed after this handler returns.
    # If get_chatmemory is a yield-dependency that closes its connection,
    # some FastAPI versions run that cleanup before streaming finishes —
    # confirm the checkpointer is still usable for the whole stream.
    return StreamingResponse(
        chat_service.post_message_stream(
            thread_id=request.thread,
            message=request.message,
            chatmodel=chatmodel,
            chatmemory=chatmemory,
            embeddings=embeddings,
        ),
        media_type="text/event-stream",
        headers={
            # Event streams must not be cached by clients or intermediaries.
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
|