feat: add and handle chatbot thread date_created and date_modified
This commit is contained in:
parent
b50dcc6c0f
commit
1143e181e8
3 changed files with 102 additions and 17 deletions
|
|
@ -1,10 +1,11 @@
|
|||
from typing import AsyncIterator
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
from sqlalchemy import String, Uuid
|
||||
from sqlalchemy import String, Uuid, DateTime
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
|
|
@ -28,6 +29,16 @@ class UserThreadMapping(Base):
|
|||
userId: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||
threadName: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
date_created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
date_updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# Dependency that pulls the sessionmaker off app.state
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncIterator, List, Optional
|
||||
|
||||
from sqlalchemy import select, update
|
||||
|
|
@ -27,7 +28,6 @@ async def save_thread_for_user(
|
|||
user: User,
|
||||
session: AsyncSession,
|
||||
thread_name: str = "New Chat",
|
||||
title: str = "New Chat",
|
||||
) -> None:
|
||||
"""Save a new chat thread mapping for the user.
|
||||
|
||||
|
|
@ -36,7 +36,6 @@ async def save_thread_for_user(
|
|||
user: The current user.
|
||||
session: The database session for saving.
|
||||
thread_name: The name of the chat thread. Defaults to "New Chat".
|
||||
title: Optional title for the chat (currently unused).
|
||||
"""
|
||||
logger.info(f"Saving new thread {thread_id} for user {user.id}")
|
||||
|
||||
|
|
@ -59,6 +58,7 @@ async def get_or_create_thread_for_user(
|
|||
user: User,
|
||||
session: AsyncSession,
|
||||
thread_name: str = "New Chat",
|
||||
refresh_date_updated: bool = False,
|
||||
) -> str:
|
||||
"""Get an existing thread or create a new one for the user.
|
||||
|
||||
|
|
@ -70,6 +70,7 @@ async def get_or_create_thread_for_user(
|
|||
user: The current user.
|
||||
session: The database session for querying/saving.
|
||||
thread_name: The name for the thread if creating new. Defaults to "New Chat".
|
||||
refresh_date_updated: If True, refreshes date_updated for existing threads. Defaults to False.
|
||||
|
||||
Returns:
|
||||
The thread_id to use (either the provided one or newly created).
|
||||
|
|
@ -84,6 +85,13 @@ async def get_or_create_thread_for_user(
|
|||
thread_id=thread_id, user=user, session=session
|
||||
)
|
||||
logger.info(f"Using existing thread {thread_id} for user {user.id}")
|
||||
|
||||
# Refresh date_updated if requested
|
||||
if refresh_date_updated:
|
||||
await refresh_thread_date_updated(
|
||||
thread_id=thread_id, user=user, session=session
|
||||
)
|
||||
|
||||
return thread_id
|
||||
else:
|
||||
# Generate new thread_id if the user did not provide a thread_id
|
||||
|
|
@ -148,6 +156,10 @@ async def update_thread_name(
|
|||
) -> None:
|
||||
"""Update the name of an existing chat thread.
|
||||
|
||||
This function performs security checks by including both threadId and userId
|
||||
in the WHERE clause of the UPDATE query, ensuring users can only update
|
||||
threads that belong to them. No separate permission check is needed.
|
||||
|
||||
Args:
|
||||
thread_id: The unique identifier for the chat thread.
|
||||
user: The current user.
|
||||
|
|
@ -155,32 +167,88 @@ async def update_thread_name(
|
|||
session: The database session for updating.
|
||||
|
||||
Raises:
|
||||
PermissionError: If the thread does not belong to the user.
|
||||
ValueError: If the thread does not exist.
|
||||
ValueError: If the thread does not exist or does not belong to the user.
|
||||
"""
|
||||
# Verify thread exists and belongs to user
|
||||
await assure_thread_exists_and_belongs_to_user(
|
||||
thread_id=thread_id,
|
||||
user=user,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}"
|
||||
)
|
||||
|
||||
# Update the thread name
|
||||
# Update the thread name and date_updated
|
||||
# Security check: WHERE clause includes both threadId AND userId
|
||||
stmt = (
|
||||
update(UserThreadMapping)
|
||||
.where(UserThreadMapping.threadId == thread_id)
|
||||
.values(threadName=new_thread_name)
|
||||
.where(
|
||||
UserThreadMapping.threadId == thread_id,
|
||||
UserThreadMapping.userId == user.id,
|
||||
)
|
||||
.values(threadName=new_thread_name, date_updated=datetime.now(timezone.utc))
|
||||
)
|
||||
await session.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
# Check if any rows were affected
|
||||
if result.rowcount == 0:
|
||||
logger.warning(
|
||||
f"Failed to update thread {thread_id} for user {user.id} - "
|
||||
"thread does not exist or does not belong to user"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Thread {thread_id} does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
logger.info(f"Successfully updated thread {thread_id} name for user {user.id}")
|
||||
|
||||
|
||||
async def refresh_thread_date_updated(
|
||||
*,
|
||||
thread_id: str,
|
||||
user: User,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
"""Refresh the date_updated timestamp for an existing chat thread.
|
||||
|
||||
This function performs security checks by including both threadId and userId
|
||||
in the WHERE clause of the UPDATE query, ensuring users can only update
|
||||
threads that belong to them. No separate permission check is needed.
|
||||
|
||||
Args:
|
||||
thread_id: The unique identifier for the chat thread.
|
||||
user: The current user.
|
||||
session: The database session for updating.
|
||||
|
||||
Raises:
|
||||
ValueError: If the thread does not exist or does not belong to the user.
|
||||
"""
|
||||
logger.info(f"Refreshing date_updated for thread {thread_id} for user {user.id}")
|
||||
|
||||
# Update the date_updated timestamp
|
||||
# Security check: WHERE clause includes both threadId AND userId
|
||||
stmt = (
|
||||
update(UserThreadMapping)
|
||||
.where(
|
||||
UserThreadMapping.threadId == thread_id,
|
||||
UserThreadMapping.userId == user.id,
|
||||
)
|
||||
.values(date_updated=datetime.now(timezone.utc))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
# Check if any rows were affected
|
||||
if result.rowcount == 0:
|
||||
logger.warning(
|
||||
f"Failed to refresh thread {thread_id} for user {user.id} - "
|
||||
"thread does not exist or does not belong to user"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Thread {thread_id} does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully refreshed date_updated for thread {thread_id} for user {user.id}"
|
||||
)
|
||||
|
||||
|
||||
async def post_message(
|
||||
*,
|
||||
thread_id: str,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
|
||||
from modules.features.chatBot.database import get_async_db_session
|
||||
from modules.features.chatBot.service import get_or_create_thread_for_user
|
||||
from modules.features.chatBot.service import (
|
||||
get_or_create_thread_for_user,
|
||||
)
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelChatbot import (
|
||||
ChatMessageRequest,
|
||||
|
|
@ -56,6 +58,8 @@ async def post_chat_message_stream(
|
|||
thread_id=message_request.thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
thread_name=message_request.message[:100],
|
||||
refresh_date_updated=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
|
@ -106,6 +110,8 @@ async def post_chat_message(
|
|||
thread_id=message_request.thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
thread_name=message_request.message[:100],
|
||||
refresh_date_updated=True,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
||||
|
|
|
|||
Loading…
Reference in a new issue