From 1143e181e851a6fba046a263af3fc79a06a74005 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 15:30:14 +0200 Subject: [PATCH] feat: add and handle chatbot thread date_created and date_modified --- modules/features/chatBot/database.py | 13 +++- modules/features/chatBot/service.py | 98 +++++++++++++++++++++++----- modules/routes/routeChatbot.py | 8 ++- 3 files changed, 102 insertions(+), 17 deletions(-) diff --git a/modules/features/chatBot/database.py b/modules/features/chatBot/database.py index ba67a28b..50dc9bba 100644 --- a/modules/features/chatBot/database.py +++ b/modules/features/chatBot/database.py @@ -1,10 +1,11 @@ from typing import AsyncIterator import uuid +from datetime import datetime, timezone from fastapi import Request from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy import String, Uuid +from sqlalchemy import String, Uuid, DateTime class Base(DeclarativeBase): @@ -28,6 +29,16 @@ class UserThreadMapping(Base): userId: Mapped[str] = mapped_column(String(255), nullable=False) threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) threadName: Mapped[str] = mapped_column(String(255), nullable=False) + date_created: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + date_updated: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) # Dependency that pulls the sessionmaker off app.state diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index 86442ae6..05df7f0f 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -2,6 +2,7 @@ import json import logging +from datetime import datetime, timezone from typing import AsyncIterator, List, Optional from sqlalchemy import select, update @@ -27,7 +28,6 @@ async def save_thread_for_user( user: User, session: AsyncSession, thread_name: str = "New Chat", - title: str = "New Chat", ) -> None: """Save a new chat thread mapping for the user. @@ -36,7 +36,6 @@ async def save_thread_for_user( user: The current user. session: The database session for saving. thread_name: The name of the chat thread. Defaults to "New Chat". - title: Optional title for the chat (currently unused). """ logger.info(f"Saving new thread {thread_id} for user {user.id}") @@ -59,6 +58,7 @@ async def get_or_create_thread_for_user( user: User, session: AsyncSession, thread_name: str = "New Chat", + refresh_date_updated: bool = False, ) -> str: """Get an existing thread or create a new one for the user. @@ -70,6 +70,7 @@ async def get_or_create_thread_for_user( user: The current user. session: The database session for querying/saving. thread_name: The name for the thread if creating new. Defaults to "New Chat". + refresh_date_updated: If True, refreshes date_updated for existing threads. Defaults to False. Returns: The thread_id to use (either the provided one or newly created). @@ -84,6 +85,13 @@ async def get_or_create_thread_for_user( thread_id=thread_id, user=user, session=session ) logger.info(f"Using existing thread {thread_id} for user {user.id}") + + # Refresh date_updated if requested + if refresh_date_updated: + await refresh_thread_date_updated( + thread_id=thread_id, user=user, session=session + ) + return thread_id else: # Generate new thread_id if the user did not provide a thread_id @@ -148,6 +156,10 @@ async def update_thread_name( ) -> None: """Update the name of an existing chat thread. + This function performs security checks by including both threadId and userId + in the WHERE clause of the UPDATE query, ensuring users can only update + threads that belong to them. No separate permission check is needed. + Args: thread_id: The unique identifier for the chat thread. user: The current user. @@ -155,32 +167,88 @@ async def update_thread_name( session: The database session for updating. Raises: - PermissionError: If the thread does not belong to the user. - ValueError: If the thread does not exist. + ValueError: If the thread does not exist or does not belong to the user. """ - # Verify thread exists and belongs to user - await assure_thread_exists_and_belongs_to_user( - thread_id=thread_id, - user=user, - session=session, - ) - logger.info( f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}" ) - # Update the thread name + # Update the thread name and date_updated + # Security check: WHERE clause includes both threadId AND userId stmt = ( update(UserThreadMapping) - .where(UserThreadMapping.threadId == thread_id) - .values(threadName=new_thread_name) + .where( + UserThreadMapping.threadId == thread_id, + UserThreadMapping.userId == user.id, + ) + .values(threadName=new_thread_name, date_updated=datetime.now(timezone.utc)) ) - await session.execute(stmt) + result = await session.execute(stmt) await session.commit() + # Check if any rows were affected + if result.rowcount == 0: + logger.warning( + f"Failed to update thread {thread_id} for user {user.id} - " + "thread does not exist or does not belong to user" + ) + raise ValueError( + f"Thread {thread_id} does not exist or you do not have permission to access it" + ) + logger.info(f"Successfully updated thread {thread_id} name for user {user.id}") +async def refresh_thread_date_updated( + *, + thread_id: str, + user: User, + session: AsyncSession, +) -> None: + """Refresh the date_updated timestamp for an existing chat thread. + + This function performs security checks by including both threadId and userId + in the WHERE clause of the UPDATE query, ensuring users can only update + threads that belong to them. No separate permission check is needed. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + session: The database session for updating. + + Raises: + ValueError: If the thread does not exist or does not belong to the user. + """ + logger.info(f"Refreshing date_updated for thread {thread_id} for user {user.id}") + + # Update the date_updated timestamp + # Security check: WHERE clause includes both threadId AND userId + stmt = ( + update(UserThreadMapping) + .where( + UserThreadMapping.threadId == thread_id, + UserThreadMapping.userId == user.id, + ) + .values(date_updated=datetime.now(timezone.utc)) + ) + result = await session.execute(stmt) + await session.commit() + + # Check if any rows were affected + if result.rowcount == 0: + logger.warning( + f"Failed to refresh thread {thread_id} for user {user.id} - " + "thread does not exist or does not belong to user" + ) + raise ValueError( + f"Thread {thread_id} does not exist or you do not have permission to access it" + ) + + logger.info( + f"Successfully refreshed date_updated for thread {thread_id} for user {user.id}" + ) + + async def post_message( *, thread_id: str, diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index 8f3efd65..e14f38c6 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -9,7 +9,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from modules.features.chatBot.database import get_async_db_session -from modules.features.chatBot.service import get_or_create_thread_for_user +from modules.features.chatBot.service import ( + get_or_create_thread_for_user, +) from modules.datamodels.datamodelUam import User from modules.datamodels.datamodelChatbot import ( ChatMessageRequest, @@ -56,6 +58,8 @@ async def post_chat_message_stream( thread_id=message_request.thread_id, user=currentUser, session=session, + thread_name=message_request.message[:100], + refresh_date_updated=True, ) logger.info( @@ -106,6 +110,8 @@ async def post_chat_message( thread_id=message_request.thread_id, user=currentUser, session=session, + thread_name=message_request.message[:100], + refresh_date_updated=True, ) logger.info(f"User {currentUser.id} posted message to thread {thread_id}")