feat: implement get threads endpoint
This commit is contained in:
parent
1143e181e8
commit
a08bd3ef1d
3 changed files with 62 additions and 34 deletions
|
|
@ -34,9 +34,9 @@ class ThreadSummary(BaseModel, ModelMixin):
|
|||
"""Summary of a chat thread for list view"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
created_at: float = Field(..., description="Thread creation timestamp")
|
||||
last_message: str = Field(..., description="Last message content")
|
||||
message_count: int = Field(..., description="Total number of messages")
|
||||
thread_name: str = Field(..., description="Thread name")
|
||||
date_created: float = Field(..., description="Thread creation timestamp")
|
||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
||||
|
||||
|
||||
class ThreadListResponse(BaseModel, ModelMixin):
|
||||
|
|
@ -96,9 +96,9 @@ register_model_labels(
|
|||
{"en": "Thread Summary", "fr": "Résumé du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"created_at": {"en": "Created At", "fr": "Créé le"},
|
||||
"last_message": {"en": "Last Message", "fr": "Dernier message"},
|
||||
"message_count": {"en": "Message Count", "fr": "Nombre de messages"},
|
||||
"thread_name": {"en": "Thread Name", "fr": "Nom du fil"},
|
||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,11 @@ from modules.features.chatBot.utils.checkpointer import get_checkpointer
|
|||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
from modules.features.chatBot.utils import permissions
|
||||
from modules.features.chatBot.database import UserThreadMapping
|
||||
from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse
|
||||
from modules.datamodels.datamodelChatbot import (
|
||||
MessageItem,
|
||||
ChatMessageResponse,
|
||||
ThreadSummary,
|
||||
)
|
||||
from modules.datamodels.datamodelUam import User
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
|
|
@ -22,6 +26,47 @@ from modules.shared.configuration import APP_CONFIG
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_all_threads_for_user(
|
||||
*,
|
||||
user: User,
|
||||
session: AsyncSession,
|
||||
) -> List[ThreadSummary]:
|
||||
"""Get all chat threads for a user.
|
||||
|
||||
Args:
|
||||
user: The current user.
|
||||
session: The database session for querying.
|
||||
|
||||
Returns:
|
||||
List of ThreadSummary objects sorted by date_updated (newest first).
|
||||
Returns empty list if no threads found.
|
||||
"""
|
||||
logger.info(f"Fetching all threads for user {user.id}")
|
||||
|
||||
# Query all threads for this user, ordered by date_updated descending
|
||||
stmt = (
|
||||
select(UserThreadMapping)
|
||||
.where(UserThreadMapping.userId == user.id)
|
||||
.order_by(UserThreadMapping.date_updated.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
thread_mappings = result.scalars().all()
|
||||
|
||||
# Convert to ThreadSummary objects
|
||||
threads = []
|
||||
for mapping in thread_mappings:
|
||||
thread_summary = ThreadSummary(
|
||||
thread_id=mapping.threadId,
|
||||
thread_name=mapping.threadName,
|
||||
date_created=mapping.date_created.timestamp(),
|
||||
date_updated=mapping.date_updated.timestamp(),
|
||||
)
|
||||
threads.append(thread_summary)
|
||||
|
||||
logger.info(f"Found {len(threads)} threads for user {user.id}")
|
||||
return threads
|
||||
|
||||
|
||||
async def save_thread_for_user(
|
||||
*,
|
||||
thread_id: str,
|
||||
|
|
|
|||
|
|
@ -143,40 +143,23 @@ async def post_chat_message(
|
|||
@router.get("/threads", response_model=ThreadListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_all_threads(
|
||||
*, request: Request, currentUser: User = Depends(getCurrentUser)
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ThreadListResponse:
|
||||
"""
|
||||
Get all chat threads for the current user.
|
||||
|
||||
This endpoint will later fetch from LangGraph's PostgreSQL checkpointer.
|
||||
"""
|
||||
try:
|
||||
# Return dummy thread data
|
||||
# In production, this will query LangGraph's checkpointer database
|
||||
dummy_threads = [
|
||||
ThreadSummary(
|
||||
thread_id="thread_001",
|
||||
created_at=datetime.now().timestamp() - 86400, # 1 day ago
|
||||
last_message="Hello, how can I help you?",
|
||||
message_count=4,
|
||||
),
|
||||
ThreadSummary(
|
||||
thread_id="thread_002",
|
||||
created_at=datetime.now().timestamp() - 3600, # 1 hour ago
|
||||
last_message="Thank you for your help!",
|
||||
message_count=8,
|
||||
),
|
||||
ThreadSummary(
|
||||
thread_id="thread_003",
|
||||
created_at=datetime.now().timestamp() - 300, # 5 minutes ago
|
||||
last_message="Can you explain this concept?",
|
||||
message_count=2,
|
||||
),
|
||||
]
|
||||
# Get all threads for the current user
|
||||
threads = await chat_service.get_all_threads_for_user(
|
||||
user=currentUser, session=session
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved {len(dummy_threads)} threads")
|
||||
logger.info(f"User {currentUser.id} retrieved {len(threads)} threads")
|
||||
|
||||
return ThreadListResponse(threads=dummy_threads)
|
||||
return ThreadListResponse(threads=threads)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
|
|||
Loading…
Reference in a new issue