feat: save chatbot threads to db

This commit is contained in:
Christopher Gondek 2025-10-08 14:41:29 +02:00
parent 30c3f9f7f1
commit b50dcc6c0f
3 changed files with 194 additions and 13 deletions

View file

@ -1,8 +1,10 @@
from typing import AsyncIterator from typing import AsyncIterator
import uuid
from fastapi import Request from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String from sqlalchemy import String, Uuid
class Base(DeclarativeBase): class Base(DeclarativeBase):
@ -15,18 +17,23 @@ class UserThreadMapping(Base):
Used to keep track of which user owns which chat thread. Used to keep track of which user owns which chat thread.
Also stores meta data like thread name. Also stores meta data like thread name.
1:N relationship between user and thread. A thread belongs to exactly one user.
A user can have multiple threads.
Thread_id is unique in the table.
""" """
__tablename__ = "userThreads" __tablename__ = "userThreads"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
userId: Mapped[int] = mapped_column(nullable=False) userId: Mapped[str] = mapped_column(String(255), nullable=False)
threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
threadName: Mapped[str] = mapped_column(String(255), nullable=False) threadName: Mapped[str] = mapped_column(String(255), nullable=False)
# Dependency that pulls the sessionmaker off app.state # Dependency that pulls the sessionmaker off app.state
# This is set in app.py on startup in @asynccontextmanager # This is set in app.py on startup in @asynccontextmanager
async def get_session(request: Request) -> AsyncIterator[AsyncSession]: # TODO: If we use SQLAlchemy in other places, we can move this to a shared module
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
SessionLocal: async_sessionmaker[AsyncSession] = ( SessionLocal: async_sessionmaker[AsyncSession] = (
request.app.state.checkpoint_sessionmaker request.app.state.checkpoint_sessionmaker
) )

View file

@ -2,12 +2,16 @@
import json import json
import logging import logging
from typing import AsyncIterator, List from typing import AsyncIterator, List, Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
from modules.features.chatBot.utils.checkpointer import get_checkpointer from modules.features.chatBot.utils.checkpointer import get_checkpointer
from modules.features.chatBot.utils.toolRegistry import get_registry from modules.features.chatBot.utils.toolRegistry import get_registry
from modules.features.chatBot.utils import permissions from modules.features.chatBot.utils import permissions
from modules.features.chatBot.database import UserThreadMapping
from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse
from modules.datamodels.datamodelUam import User from modules.datamodels.datamodelUam import User
@ -17,6 +21,166 @@ from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def save_thread_for_user(
*,
thread_id: str,
user: User,
session: AsyncSession,
thread_name: str = "New Chat",
title: str = "New Chat",
) -> None:
"""Save a new chat thread mapping for the user.
Args:
thread_id: The unique identifier for the chat thread.
user: The current user.
session: The database session for saving.
thread_name: The name of the chat thread. Defaults to "New Chat".
title: Optional title for the chat (currently unused).
"""
logger.info(f"Saving new thread {thread_id} for user {user.id}")
# Create new mapping entry
new_mapping = UserThreadMapping(
userId=user.id,
threadId=thread_id,
threadName=thread_name,
)
session.add(new_mapping)
await session.commit()
logger.info(f"Successfully saved thread {thread_id} for user {user.id}")
async def get_or_create_thread_for_user(
*,
thread_id: Optional[str],
user: User,
session: AsyncSession,
thread_name: str = "New Chat",
) -> str:
"""Get an existing thread or create a new one for the user.
If thread_id is provided, verifies it exists and belongs to the user.
If thread_id is None, generates a new thread_id and saves it.
Args:
thread_id: Optional thread identifier. If None, creates a new thread.
user: The current user.
session: The database session for querying/saving.
thread_name: The name for the thread if creating new. Defaults to "New Chat".
Returns:
The thread_id to use (either the provided one or newly created).
Raises:
PermissionError: If the thread does not belong to the user.
ValueError: If the provided thread_id does not exist.
"""
if thread_id:
# If the user provided a thread_id, verify it exists and belongs to them
await assure_thread_exists_and_belongs_to_user(
thread_id=thread_id, user=user, session=session
)
logger.info(f"Using existing thread {thread_id} for user {user.id}")
return thread_id
else:
# Generate new thread_id if the user did not provide a thread_id
import uuid
new_thread_id = f"thread_{uuid.uuid4()}"
await save_thread_for_user(
thread_id=new_thread_id,
user=user,
session=session,
thread_name=thread_name,
)
logger.info(f"Created new thread {new_thread_id} for user {user.id}")
return new_thread_id
async def assure_thread_exists_and_belongs_to_user(
*,
thread_id: str,
user: User,
session: AsyncSession,
) -> None:
"""Ensure that the given thread ID exists and belongs to the specified user.
Args:
thread_id: The unique identifier for the chat thread.
user: The current user.
session: The database session for querying.
Raises:
PermissionError: If the thread does not belong to the user.
ValueError: If the thread does not exist.
"""
# Query the database for the thread mapping
stmt = select(UserThreadMapping).where(UserThreadMapping.threadId == thread_id)
result = await session.execute(stmt)
thread_mapping = result.scalar_one_or_none()
# Check if thread exists
if thread_mapping is None:
logger.warning(f"Thread {thread_id} does not exist")
raise ValueError(f"Thread {thread_id} does not exist")
# Check if thread belongs to the user
if thread_mapping.userId != user.id:
logger.warning(
f"User {user.id} attempted to access thread {thread_id} "
f"belonging to user {thread_mapping.userId}"
)
raise PermissionError(
f"You do not have permission to access thread {thread_id}"
)
logger.info(f"Thread {thread_id} verified for user {user.id}")
async def update_thread_name(
*,
thread_id: str,
user: User,
new_thread_name: str,
session: AsyncSession,
) -> None:
"""Update the name of an existing chat thread.
Args:
thread_id: The unique identifier for the chat thread.
user: The current user.
new_thread_name: The new name to set for the thread.
session: The database session for updating.
Raises:
PermissionError: If the thread does not belong to the user.
ValueError: If the thread does not exist.
"""
# Verify thread exists and belongs to user
await assure_thread_exists_and_belongs_to_user(
thread_id=thread_id,
user=user,
session=session,
)
logger.info(
f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}"
)
# Update the thread name
stmt = (
update(UserThreadMapping)
.where(UserThreadMapping.threadId == thread_id)
.values(threadName=new_thread_name)
)
await session.execute(stmt)
await session.commit()
logger.info(f"Successfully updated thread {thread_id} name for user {user.id}")
async def post_message( async def post_message(
*, *,
thread_id: str, thread_id: str,

View file

@ -5,7 +5,11 @@ from typing import Any, Dict, List, Optional
from datetime import datetime from datetime import datetime
import logging import logging
import uuid import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from modules.features.chatBot.database import get_async_db_session
from modules.features.chatBot.service import get_or_create_thread_for_user
from modules.datamodels.datamodelUam import User from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelChatbot import ( from modules.datamodels.datamodelChatbot import (
ChatMessageRequest, ChatMessageRequest,
@ -38,6 +42,7 @@ async def post_chat_message_stream(
request: Request, request: Request,
message_request: ChatMessageRequest, message_request: ChatMessageRequest,
currentUser: User = Depends(getCurrentUser), currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> StreamingResponse: ) -> StreamingResponse:
""" """
Post a message to a chat thread with streaming progress updates. Post a message to a chat thread with streaming progress updates.
@ -46,12 +51,12 @@ async def post_chat_message_stream(
Returns Server-Sent Events (SSE) stream with status updates and final response. Returns Server-Sent Events (SSE) stream with status updates and final response.
""" """
try: try:
# TODO: Add helper here, if no thread id is provided, add entry in mapping table. # Get or create thread using helper function
thread_id = await get_or_create_thread_for_user(
# TODO: If not provided, create new thread in LangGraph's checkpointer, and add it to mapping table. thread_id=message_request.thread_id,
user=currentUser,
# Generate or use existing thread_id session=session,
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" )
logger.info( logger.info(
f"User {currentUser.id} posted streaming message to thread {thread_id}" f"User {currentUser.id} posted streaming message to thread {thread_id}"
@ -87,6 +92,7 @@ async def post_chat_message(
request: Request, request: Request,
message_request: ChatMessageRequest, message_request: ChatMessageRequest,
currentUser: User = Depends(getCurrentUser), currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> ChatMessageResponse: ) -> ChatMessageResponse:
""" """
Post a message to a chat thread and get assistant response (non-streaming). Post a message to a chat thread and get assistant response (non-streaming).
@ -95,8 +101,12 @@ async def post_chat_message(
For streaming updates, use the /message/stream endpoint instead. For streaming updates, use the /message/stream endpoint instead.
""" """
try: try:
# Generate or use existing thread_id # Get or create thread using helper function
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" thread_id = await get_or_create_thread_for_user(
thread_id=message_request.thread_id,
user=currentUser,
session=session,
)
logger.info(f"User {currentUser.id} posted message to thread {thread_id}") logger.info(f"User {currentUser.id} posted message to thread {thread_id}")