feat: save chatbot threads to db
This commit is contained in:
parent
30c3f9f7f1
commit
b50dcc6c0f
3 changed files with 194 additions and 13 deletions
|
|
@ -1,8 +1,10 @@
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator
|
||||||
|
import uuid
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
from sqlalchemy import String
|
from sqlalchemy import String, Uuid
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
|
|
@ -15,18 +17,23 @@ class UserThreadMapping(Base):
|
||||||
|
|
||||||
Used to keep track of which user owns which chat thread.
|
Used to keep track of which user owns which chat thread.
|
||||||
Also stores meta data like thread name.
|
Also stores meta data like thread name.
|
||||||
|
|
||||||
|
1:N relationship between user and thread. A thread belongs to exactly one user.
|
||||||
|
A user can have multiple threads.
|
||||||
|
Thread_id is unique in the table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "userThreads"
|
__tablename__ = "userThreads"
|
||||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
userId: Mapped[int] = mapped_column(nullable=False)
|
userId: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||||
threadName: Mapped[str] = mapped_column(String(255), nullable=False)
|
threadName: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
# Dependency that pulls the sessionmaker off app.state
|
# Dependency that pulls the sessionmaker off app.state
|
||||||
# This is set in app.py on startup in @asynccontextmanager
|
# This is set in app.py on startup in @asynccontextmanager
|
||||||
async def get_session(request: Request) -> AsyncIterator[AsyncSession]:
|
# TODO: If we use SQLAlchemy in other places, we can move this to a shared module
|
||||||
|
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
|
||||||
SessionLocal: async_sessionmaker[AsyncSession] = (
|
SessionLocal: async_sessionmaker[AsyncSession] = (
|
||||||
request.app.state.checkpoint_sessionmaker
|
request.app.state.checkpoint_sessionmaker
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,16 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncIterator, List
|
from typing import AsyncIterator, List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
|
from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
|
||||||
from modules.features.chatBot.utils.checkpointer import get_checkpointer
|
from modules.features.chatBot.utils.checkpointer import get_checkpointer
|
||||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||||
from modules.features.chatBot.utils import permissions
|
from modules.features.chatBot.utils import permissions
|
||||||
|
from modules.features.chatBot.database import UserThreadMapping
|
||||||
from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse
|
from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse
|
||||||
from modules.datamodels.datamodelUam import User
|
from modules.datamodels.datamodelUam import User
|
||||||
|
|
||||||
|
|
@ -17,6 +21,166 @@ from modules.shared.configuration import APP_CONFIG
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def save_thread_for_user(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
user: User,
|
||||||
|
session: AsyncSession,
|
||||||
|
thread_name: str = "New Chat",
|
||||||
|
title: str = "New Chat",
|
||||||
|
) -> None:
|
||||||
|
"""Save a new chat thread mapping for the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: The unique identifier for the chat thread.
|
||||||
|
user: The current user.
|
||||||
|
session: The database session for saving.
|
||||||
|
thread_name: The name of the chat thread. Defaults to "New Chat".
|
||||||
|
title: Optional title for the chat (currently unused).
|
||||||
|
"""
|
||||||
|
logger.info(f"Saving new thread {thread_id} for user {user.id}")
|
||||||
|
|
||||||
|
# Create new mapping entry
|
||||||
|
new_mapping = UserThreadMapping(
|
||||||
|
userId=user.id,
|
||||||
|
threadId=thread_id,
|
||||||
|
threadName=thread_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(new_mapping)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(f"Successfully saved thread {thread_id} for user {user.id}")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create_thread_for_user(
|
||||||
|
*,
|
||||||
|
thread_id: Optional[str],
|
||||||
|
user: User,
|
||||||
|
session: AsyncSession,
|
||||||
|
thread_name: str = "New Chat",
|
||||||
|
) -> str:
|
||||||
|
"""Get an existing thread or create a new one for the user.
|
||||||
|
|
||||||
|
If thread_id is provided, verifies it exists and belongs to the user.
|
||||||
|
If thread_id is None, generates a new thread_id and saves it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: Optional thread identifier. If None, creates a new thread.
|
||||||
|
user: The current user.
|
||||||
|
session: The database session for querying/saving.
|
||||||
|
thread_name: The name for the thread if creating new. Defaults to "New Chat".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The thread_id to use (either the provided one or newly created).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PermissionError: If the thread does not belong to the user.
|
||||||
|
ValueError: If the provided thread_id does not exist.
|
||||||
|
"""
|
||||||
|
if thread_id:
|
||||||
|
# If the user provided a thread_id, verify it exists and belongs to them
|
||||||
|
await assure_thread_exists_and_belongs_to_user(
|
||||||
|
thread_id=thread_id, user=user, session=session
|
||||||
|
)
|
||||||
|
logger.info(f"Using existing thread {thread_id} for user {user.id}")
|
||||||
|
return thread_id
|
||||||
|
else:
|
||||||
|
# Generate new thread_id if the user did not provide a thread_id
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
new_thread_id = f"thread_{uuid.uuid4()}"
|
||||||
|
await save_thread_for_user(
|
||||||
|
thread_id=new_thread_id,
|
||||||
|
user=user,
|
||||||
|
session=session,
|
||||||
|
thread_name=thread_name,
|
||||||
|
)
|
||||||
|
logger.info(f"Created new thread {new_thread_id} for user {user.id}")
|
||||||
|
return new_thread_id
|
||||||
|
|
||||||
|
|
||||||
|
async def assure_thread_exists_and_belongs_to_user(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
user: User,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure that the given thread ID exists and belongs to the specified user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: The unique identifier for the chat thread.
|
||||||
|
user: The current user.
|
||||||
|
session: The database session for querying.
|
||||||
|
Raises:
|
||||||
|
PermissionError: If the thread does not belong to the user.
|
||||||
|
ValueError: If the thread does not exist.
|
||||||
|
"""
|
||||||
|
# Query the database for the thread mapping
|
||||||
|
stmt = select(UserThreadMapping).where(UserThreadMapping.threadId == thread_id)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
thread_mapping = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# Check if thread exists
|
||||||
|
if thread_mapping is None:
|
||||||
|
logger.warning(f"Thread {thread_id} does not exist")
|
||||||
|
raise ValueError(f"Thread {thread_id} does not exist")
|
||||||
|
|
||||||
|
# Check if thread belongs to the user
|
||||||
|
if thread_mapping.userId != user.id:
|
||||||
|
logger.warning(
|
||||||
|
f"User {user.id} attempted to access thread {thread_id} "
|
||||||
|
f"belonging to user {thread_mapping.userId}"
|
||||||
|
)
|
||||||
|
raise PermissionError(
|
||||||
|
f"You do not have permission to access thread {thread_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Thread {thread_id} verified for user {user.id}")
|
||||||
|
|
||||||
|
|
||||||
|
async def update_thread_name(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
user: User,
|
||||||
|
new_thread_name: str,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Update the name of an existing chat thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: The unique identifier for the chat thread.
|
||||||
|
user: The current user.
|
||||||
|
new_thread_name: The new name to set for the thread.
|
||||||
|
session: The database session for updating.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PermissionError: If the thread does not belong to the user.
|
||||||
|
ValueError: If the thread does not exist.
|
||||||
|
"""
|
||||||
|
# Verify thread exists and belongs to user
|
||||||
|
await assure_thread_exists_and_belongs_to_user(
|
||||||
|
thread_id=thread_id,
|
||||||
|
user=user,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the thread name
|
||||||
|
stmt = (
|
||||||
|
update(UserThreadMapping)
|
||||||
|
.where(UserThreadMapping.threadId == thread_id)
|
||||||
|
.values(threadName=new_thread_name)
|
||||||
|
)
|
||||||
|
await session.execute(stmt)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(f"Successfully updated thread {thread_id} name for user {user.id}")
|
||||||
|
|
||||||
|
|
||||||
async def post_message(
|
async def post_message(
|
||||||
*,
|
*,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,11 @@ from typing import Any, Dict, List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
from modules.features.chatBot.database import get_async_db_session
|
||||||
|
from modules.features.chatBot.service import get_or_create_thread_for_user
|
||||||
from modules.datamodels.datamodelUam import User
|
from modules.datamodels.datamodelUam import User
|
||||||
from modules.datamodels.datamodelChatbot import (
|
from modules.datamodels.datamodelChatbot import (
|
||||||
ChatMessageRequest,
|
ChatMessageRequest,
|
||||||
|
|
@ -38,6 +42,7 @@ async def post_chat_message_stream(
|
||||||
request: Request,
|
request: Request,
|
||||||
message_request: ChatMessageRequest,
|
message_request: ChatMessageRequest,
|
||||||
currentUser: User = Depends(getCurrentUser),
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
session: AsyncSession = Depends(get_async_db_session),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""
|
"""
|
||||||
Post a message to a chat thread with streaming progress updates.
|
Post a message to a chat thread with streaming progress updates.
|
||||||
|
|
@ -46,12 +51,12 @@ async def post_chat_message_stream(
|
||||||
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# TODO: Add helper here, if no thread id is provided, add entry in mapping table.
|
# Get or create thread using helper function
|
||||||
|
thread_id = await get_or_create_thread_for_user(
|
||||||
# TODO: If not provided, create new thread in LangGraph's checkpointer, and add it to mapping table.
|
thread_id=message_request.thread_id,
|
||||||
|
user=currentUser,
|
||||||
# Generate or use existing thread_id
|
session=session,
|
||||||
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}"
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {currentUser.id} posted streaming message to thread {thread_id}"
|
f"User {currentUser.id} posted streaming message to thread {thread_id}"
|
||||||
|
|
@ -87,6 +92,7 @@ async def post_chat_message(
|
||||||
request: Request,
|
request: Request,
|
||||||
message_request: ChatMessageRequest,
|
message_request: ChatMessageRequest,
|
||||||
currentUser: User = Depends(getCurrentUser),
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
session: AsyncSession = Depends(get_async_db_session),
|
||||||
) -> ChatMessageResponse:
|
) -> ChatMessageResponse:
|
||||||
"""
|
"""
|
||||||
Post a message to a chat thread and get assistant response (non-streaming).
|
Post a message to a chat thread and get assistant response (non-streaming).
|
||||||
|
|
@ -95,8 +101,12 @@ async def post_chat_message(
|
||||||
For streaming updates, use the /message/stream endpoint instead.
|
For streaming updates, use the /message/stream endpoint instead.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Generate or use existing thread_id
|
# Get or create thread using helper function
|
||||||
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}"
|
thread_id = await get_or_create_thread_for_user(
|
||||||
|
thread_id=message_request.thread_id,
|
||||||
|
user=currentUser,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue