feat: implement rename, delete thread endpoints
This commit is contained in:
parent
a0a87e2e3e
commit
bfc07ee0b1
3 changed files with 175 additions and 10 deletions
|
|
@ -55,6 +55,12 @@ class ThreadDetail(BaseModel, ModelMixin):
|
|||
)
|
||||
|
||||
|
||||
class RenameThreadRequest(BaseModel, ModelMixin):
|
||||
"""Request model for renaming a thread"""
|
||||
|
||||
new_name: str = Field(..., description="New name for the thread")
|
||||
|
||||
|
||||
class DeleteResponse(BaseModel, ModelMixin):
|
||||
"""Response model for delete operations"""
|
||||
|
||||
|
|
@ -120,6 +126,14 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"RenameThreadRequest",
|
||||
{"en": "Rename Thread Request", "fr": "Demande de renommage de fil"},
|
||||
{
|
||||
"new_name": {"en": "New Name", "fr": "Nouveau nom"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"DeleteResponse",
|
||||
{"en": "Delete Response", "fr": "Réponse de suppression"},
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import logging
|
|||
from datetime import datetime, timezone
|
||||
from typing import AsyncIterator, List, Optional
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
|
||||
|
|
@ -612,3 +612,91 @@ async def get_thread_detail_for_user(
|
|||
date_updated=thread_mapping.date_updated.timestamp(),
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
async def delete_thread_for_user(
|
||||
*,
|
||||
thread_id: str,
|
||||
user: User,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
"""Delete a chat thread for a user from both LangGraph and the database.
|
||||
|
||||
Args:
|
||||
thread_id: The unique identifier for the chat thread.
|
||||
user: The current user.
|
||||
session: The database session for deleting.
|
||||
|
||||
Raises:
|
||||
PermissionError: If the thread does not belong to the user.
|
||||
ValueError: If the thread does not exist.
|
||||
"""
|
||||
logger.info(f"Deleting thread {thread_id} for user {user.id}")
|
||||
|
||||
# Verify thread exists and belongs to user
|
||||
await assure_thread_exists_and_belongs_to_user(
|
||||
thread_id=thread_id, user=user, session=session
|
||||
)
|
||||
|
||||
# Build the chatbot app to access the checkpointer
|
||||
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
||||
if not tool_ids:
|
||||
raise ValueError("User does not have permission to use any chatbot tools")
|
||||
|
||||
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||
|
||||
# Get tools from registry
|
||||
registry = get_registry()
|
||||
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||
|
||||
# Get model and checkpointer
|
||||
model = get_langchain_model(model_name=model_name)
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
# Get context window size from config
|
||||
context_window_size = int(
|
||||
APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
|
||||
)
|
||||
|
||||
# Create chatbot instance
|
||||
chatbot = await Chatbot.create(
|
||||
model=model,
|
||||
memory=checkpointer,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
context_window_size=context_window_size,
|
||||
)
|
||||
|
||||
# Delete from LangGraph checkpointer
|
||||
try:
|
||||
await chatbot.app.checkpointer.adelete_thread(thread_id)
|
||||
logger.info(f"Deleted thread {thread_id} from LangGraph checkpointer")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete thread {thread_id} from LangGraph: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to delete thread from LangGraph: {type(e).__name__}: {str(e)}"
|
||||
)
|
||||
|
||||
# Delete from database
|
||||
stmt = delete(UserThreadMapping).where(
|
||||
UserThreadMapping.thread_id == thread_id,
|
||||
UserThreadMapping.user_id == user.id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
# Check if any rows were deleted
|
||||
if result.rowcount == 0:
|
||||
logger.warning(
|
||||
f"Failed to delete thread {thread_id} from database for user {user.id} - "
|
||||
"thread does not exist or does not belong to user"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Thread {thread_id} does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
logger.info(f"Successfully deleted thread {thread_id} for user {user.id}")
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from modules.datamodels.datamodelChatbot import (
|
|||
ThreadSummary,
|
||||
ThreadListResponse,
|
||||
ThreadDetail,
|
||||
RenameThreadRequest,
|
||||
DeleteResponse,
|
||||
)
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
|
|
@ -216,29 +217,91 @@ async def get_thread_by_id(
|
|||
)
|
||||
|
||||
|
||||
@router.patch("/threads/{thread_id}", response_model=DeleteResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def rename_thread(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
rename_request: RenameThreadRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> DeleteResponse:
|
||||
"""
|
||||
Rename a chat thread.
|
||||
"""
|
||||
try:
|
||||
await chat_service.update_thread_name(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
new_thread_name=rename_request.new_name,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} renamed thread {thread_id} to '{rename_request.new_name}'"
|
||||
)
|
||||
|
||||
return DeleteResponse(
|
||||
message=f"Thread {thread_id} successfully renamed",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found or permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error renaming thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to rename thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/threads/{thread_id}", response_model=DeleteResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_thread(
|
||||
*, request: Request, thread_id: str, currentUser: User = Depends(getCurrentUser)
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> DeleteResponse:
|
||||
"""
|
||||
Delete a chat thread and all its associated data.
|
||||
|
||||
This endpoint will later delete from LangGraph's PostgreSQL checkpointer.
|
||||
Delete a chat thread and all its associated data from both LangGraph and database.
|
||||
"""
|
||||
try:
|
||||
# In production, this will:
|
||||
# 1. Verify the thread belongs to the current user
|
||||
# 2. Delete the thread from LangGraph's checkpointer
|
||||
# 3. Clean up any associated data
|
||||
await chat_service.delete_thread_for_user(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} deleted thread {thread_id}")
|
||||
|
||||
return DeleteResponse(
|
||||
message=f"Thread {thread_id} successfully deleted (dummy response)",
|
||||
message=f"Thread {thread_id} successfully deleted",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found or permission denied",
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error deleting thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
|
|
|
|||
Loading…
Reference in a new issue