feat: add and handle chatbot thread date_created and date_modified

This commit is contained in:
Christopher Gondek 2025-10-08 15:30:14 +02:00
parent b50dcc6c0f
commit 1143e181e8
3 changed files with 102 additions and 17 deletions

View file

@ -1,10 +1,11 @@
from typing import AsyncIterator
import uuid
from datetime import datetime, timezone
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Uuid
from sqlalchemy import String, Uuid, DateTime
class Base(DeclarativeBase):
@ -28,6 +29,16 @@ class UserThreadMapping(Base):
userId: Mapped[str] = mapped_column(String(255), nullable=False)
threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
threadName: Mapped[str] = mapped_column(String(255), nullable=False)
date_created: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc),
)
date_updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc),
)
# Dependency that pulls the sessionmaker off app.state

View file

@ -2,6 +2,7 @@
import json
import logging
from datetime import datetime, timezone
from typing import AsyncIterator, List, Optional
from sqlalchemy import select, update
@ -27,7 +28,6 @@ async def save_thread_for_user(
user: User,
session: AsyncSession,
thread_name: str = "New Chat",
title: str = "New Chat",
) -> None:
"""Save a new chat thread mapping for the user.
@ -36,7 +36,6 @@ async def save_thread_for_user(
user: The current user.
session: The database session for saving.
thread_name: The name of the chat thread. Defaults to "New Chat".
title: Optional title for the chat (currently unused).
"""
logger.info(f"Saving new thread {thread_id} for user {user.id}")
@ -59,6 +58,7 @@ async def get_or_create_thread_for_user(
user: User,
session: AsyncSession,
thread_name: str = "New Chat",
refresh_date_updated: bool = False,
) -> str:
"""Get an existing thread or create a new one for the user.
@ -70,6 +70,7 @@ async def get_or_create_thread_for_user(
user: The current user.
session: The database session for querying/saving.
thread_name: The name for the thread if creating new. Defaults to "New Chat".
refresh_date_updated: If True, refreshes date_updated for existing threads. Defaults to False.
Returns:
The thread_id to use (either the provided one or newly created).
@ -84,6 +85,13 @@ async def get_or_create_thread_for_user(
thread_id=thread_id, user=user, session=session
)
logger.info(f"Using existing thread {thread_id} for user {user.id}")
# Refresh date_updated if requested
if refresh_date_updated:
await refresh_thread_date_updated(
thread_id=thread_id, user=user, session=session
)
return thread_id
else:
# Generate new thread_id if the user did not provide a thread_id
@ -148,6 +156,10 @@ async def update_thread_name(
) -> None:
"""Update the name of an existing chat thread.
This function performs security checks by including both threadId and userId
in the WHERE clause of the UPDATE query, ensuring users can only update
threads that belong to them. No separate permission check is needed.
Args:
thread_id: The unique identifier for the chat thread.
user: The current user.
@ -155,32 +167,88 @@ async def update_thread_name(
session: The database session for updating.
Raises:
PermissionError: If the thread does not belong to the user.
ValueError: If the thread does not exist.
ValueError: If the thread does not exist or does not belong to the user.
"""
# Verify thread exists and belongs to user
await assure_thread_exists_and_belongs_to_user(
thread_id=thread_id,
user=user,
session=session,
)
logger.info(
f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}"
)
# Update the thread name
# Update the thread name and date_updated
# Security check: WHERE clause includes both threadId AND userId
stmt = (
update(UserThreadMapping)
.where(UserThreadMapping.threadId == thread_id)
.values(threadName=new_thread_name)
.where(
UserThreadMapping.threadId == thread_id,
UserThreadMapping.userId == user.id,
)
.values(threadName=new_thread_name, date_updated=datetime.now(timezone.utc))
)
await session.execute(stmt)
result = await session.execute(stmt)
await session.commit()
# Check if any rows were affected
if result.rowcount == 0:
logger.warning(
f"Failed to update thread {thread_id} for user {user.id} - "
"thread does not exist or does not belong to user"
)
raise ValueError(
f"Thread {thread_id} does not exist or you do not have permission to access it"
)
logger.info(f"Successfully updated thread {thread_id} name for user {user.id}")
async def refresh_thread_date_updated(
*,
thread_id: str,
user: User,
session: AsyncSession,
) -> None:
"""Refresh the date_updated timestamp for an existing chat thread.
This function performs security checks by including both threadId and userId
in the WHERE clause of the UPDATE query, ensuring users can only update
threads that belong to them. No separate permission check is needed.
Args:
thread_id: The unique identifier for the chat thread.
user: The current user.
session: The database session for updating.
Raises:
ValueError: If the thread does not exist or does not belong to the user.
"""
logger.info(f"Refreshing date_updated for thread {thread_id} for user {user.id}")
# Update the date_updated timestamp
# Security check: WHERE clause includes both threadId AND userId
stmt = (
update(UserThreadMapping)
.where(
UserThreadMapping.threadId == thread_id,
UserThreadMapping.userId == user.id,
)
.values(date_updated=datetime.now(timezone.utc))
)
result = await session.execute(stmt)
await session.commit()
# Check if any rows were affected
if result.rowcount == 0:
logger.warning(
f"Failed to refresh thread {thread_id} for user {user.id} - "
"thread does not exist or does not belong to user"
)
raise ValueError(
f"Thread {thread_id} does not exist or you do not have permission to access it"
)
logger.info(
f"Successfully refreshed date_updated for thread {thread_id} for user {user.id}"
)
async def post_message(
*,
thread_id: str,

View file

@ -9,7 +9,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from modules.features.chatBot.database import get_async_db_session
from modules.features.chatBot.service import get_or_create_thread_for_user
from modules.features.chatBot.service import (
get_or_create_thread_for_user,
)
from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelChatbot import (
ChatMessageRequest,
@ -56,6 +58,8 @@ async def post_chat_message_stream(
thread_id=message_request.thread_id,
user=currentUser,
session=session,
thread_name=message_request.message[:100],
refresh_date_updated=True,
)
logger.info(
@ -106,6 +110,8 @@ async def post_chat_message(
thread_id=message_request.thread_id,
user=currentUser,
session=session,
thread_name=message_request.message[:100],
refresh_date_updated=True,
)
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")