diff --git a/modules/features/chatBot/database.py b/modules/features/chatBot/database.py index fc190205..ba67a28b 100644 --- a/modules/features/chatBot/database.py +++ b/modules/features/chatBot/database.py @@ -1,8 +1,10 @@ from typing import AsyncIterator +import uuid + from fastapi import Request from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy import String +from sqlalchemy import String, Uuid class Base(DeclarativeBase): @@ -15,18 +17,23 @@ class UserThreadMapping(Base): Used to keep track of which user owns which chat thread. Also stores meta data like thread name. + + 1:N relationship between user and thread. A thread belongs to exactly one user. + A user can have multiple threads. + Thread_id is unique in the table. """ __tablename__ = "userThreads" - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - userId: Mapped[int] = mapped_column(nullable=False) + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + userId: Mapped[str] = mapped_column(String(255), nullable=False) threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) threadName: Mapped[str] = mapped_column(String(255), nullable=False) # Dependency that pulls the sessionmaker off app.state # This is set in app.py on startup in @asynccontextmanager -async def get_session(request: Request) -> AsyncIterator[AsyncSession]: +# TODO: If we use SQLAlchemy in other places, we can move this to a shared module +async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]: SessionLocal: async_sessionmaker[AsyncSession] = ( request.app.state.checkpoint_sessionmaker ) diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index b88f1753..86442ae6 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -2,12 +2,16 @@ import json import logging -from typing import AsyncIterator, List +from typing import AsyncIterator, List, Optional + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model from modules.features.chatBot.utils.checkpointer import get_checkpointer from modules.features.chatBot.utils.toolRegistry import get_registry from modules.features.chatBot.utils import permissions +from modules.features.chatBot.database import UserThreadMapping from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse from modules.datamodels.datamodelUam import User @@ -17,6 +21,166 @@ from modules.shared.configuration import APP_CONFIG logger = logging.getLogger(__name__) +async def save_thread_for_user( + *, + thread_id: str, + user: User, + session: AsyncSession, + thread_name: str = "New Chat", + title: str = "New Chat", +) -> None: + """Save a new chat thread mapping for the user. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + session: The database session for saving. + thread_name: The name of the chat thread. Defaults to "New Chat". + title: Optional title for the chat (currently unused). + """ + logger.info(f"Saving new thread {thread_id} for user {user.id}") + + # Create new mapping entry + new_mapping = UserThreadMapping( + userId=user.id, + threadId=thread_id, + threadName=thread_name, + ) + + session.add(new_mapping) + await session.commit() + + logger.info(f"Successfully saved thread {thread_id} for user {user.id}") + + +async def get_or_create_thread_for_user( + *, + thread_id: Optional[str], + user: User, + session: AsyncSession, + thread_name: str = "New Chat", +) -> str: + """Get an existing thread or create a new one for the user. + + If thread_id is provided, verifies it exists and belongs to the user. + If thread_id is None, generates a new thread_id and saves it. + + Args: + thread_id: Optional thread identifier. If None, creates a new thread. + user: The current user. + session: The database session for querying/saving. + thread_name: The name for the thread if creating new. Defaults to "New Chat". + + Returns: + The thread_id to use (either the provided one or newly created). + + Raises: + PermissionError: If the thread does not belong to the user. + ValueError: If the provided thread_id does not exist. + """ + if thread_id: + # If the user provided a thread_id, verify it exists and belongs to them + await assure_thread_exists_and_belongs_to_user( + thread_id=thread_id, user=user, session=session + ) + logger.info(f"Using existing thread {thread_id} for user {user.id}") + return thread_id + else: + # Generate new thread_id if the user did not provide a thread_id + import uuid + + new_thread_id = f"thread_{uuid.uuid4()}" + await save_thread_for_user( + thread_id=new_thread_id, + user=user, + session=session, + thread_name=thread_name, + ) + logger.info(f"Created new thread {new_thread_id} for user {user.id}") + return new_thread_id + + +async def assure_thread_exists_and_belongs_to_user( + *, + thread_id: str, + user: User, + session: AsyncSession, +) -> None: + """Ensure that the given thread ID exists and belongs to the specified user. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + session: The database session for querying. + Raises: + PermissionError: If the thread does not belong to the user. + ValueError: If the thread does not exist. + """ + # Query the database for the thread mapping + stmt = select(UserThreadMapping).where(UserThreadMapping.threadId == thread_id) + result = await session.execute(stmt) + thread_mapping = result.scalar_one_or_none() + + # Check if thread exists + if thread_mapping is None: + logger.warning(f"Thread {thread_id} does not exist") + raise ValueError(f"Thread {thread_id} does not exist") + + # Check if thread belongs to the user + if thread_mapping.userId != user.id: + logger.warning( + f"User {user.id} attempted to access thread {thread_id} " + f"belonging to user {thread_mapping.userId}" + ) + raise PermissionError( + f"You do not have permission to access thread {thread_id}" + ) + + logger.info(f"Thread {thread_id} verified for user {user.id}") + + +async def update_thread_name( + *, + thread_id: str, + user: User, + new_thread_name: str, + session: AsyncSession, +) -> None: + """Update the name of an existing chat thread. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + new_thread_name: The new name to set for the thread. + session: The database session for updating. + + Raises: + PermissionError: If the thread does not belong to the user. + ValueError: If the thread does not exist. + """ + # Verify thread exists and belongs to user + await assure_thread_exists_and_belongs_to_user( + thread_id=thread_id, + user=user, + session=session, + ) + + logger.info( + f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}" + ) + + # Update the thread name + stmt = ( + update(UserThreadMapping) + .where(UserThreadMapping.threadId == thread_id) + .values(threadName=new_thread_name) + ) + await session.execute(stmt) + await session.commit() + + logger.info(f"Successfully updated thread {thread_id} name for user {user.id}") + + async def post_message( *, thread_id: str, diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index 48f74e5b..8f3efd65 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -5,7 +5,11 @@ from typing import Any, Dict, List, Optional from datetime import datetime import logging import uuid +from sqlalchemy.ext.asyncio import AsyncSession + +from modules.features.chatBot.database import get_async_db_session +from modules.features.chatBot.service import get_or_create_thread_for_user from modules.datamodels.datamodelUam import User from modules.datamodels.datamodelChatbot import ( ChatMessageRequest, @@ -38,6 +42,7 @@ async def post_chat_message_stream( request: Request, message_request: ChatMessageRequest, currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), ) -> StreamingResponse: """ Post a message to a chat thread with streaming progress updates. @@ -46,12 +51,12 @@ async def post_chat_message_stream( Returns Server-Sent Events (SSE) stream with status updates and final response. """ try: - # TODO: Add helper here, if no thread id is provided, add entry in mapping table. - - # TODO: If not provided, create new thread in LangGraph's checkpointer, and add it to mapping table. - - # Generate or use existing thread_id - thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" + # Get or create thread using helper function + thread_id = await get_or_create_thread_for_user( + thread_id=message_request.thread_id, + user=currentUser, + session=session, + ) logger.info( f"User {currentUser.id} posted streaming message to thread {thread_id}" @@ -87,6 +92,7 @@ async def post_chat_message( request: Request, message_request: ChatMessageRequest, currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), ) -> ChatMessageResponse: """ Post a message to a chat thread and get assistant response (non-streaming). @@ -95,8 +101,12 @@ async def post_chat_message( For streaming updates, use the /message/stream endpoint instead. """ try: - # Generate or use existing thread_id - thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" + # Get or create thread using helper function + thread_id = await get_or_create_thread_for_user( + thread_id=message_request.thread_id, + user=currentUser, + session=session, + ) logger.info(f"User {currentUser.id} posted message to thread {thread_id}")