feat: implement get thread details endpoint
This commit is contained in:
parent
ed3920f9f9
commit
85503fc669
3 changed files with 154 additions and 38 deletions
|
|
@ -49,7 +49,8 @@ class ThreadDetail(BaseModel, ModelMixin):
|
|||
"""Detailed view of a single thread"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
created_at: float = Field(..., description="Thread creation timestamp")
|
||||
date_created: float = Field(..., description="Thread creation timestamp")
|
||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
||||
messages: List[MessageItem] = Field(
|
||||
..., description="All messages in chronological order"
|
||||
)
|
||||
|
|
@ -115,7 +116,8 @@ register_model_labels(
|
|||
{"en": "Thread Detail", "fr": "Détail du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"created_at": {"en": "Created At", "fr": "Créé le"},
|
||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
||||
"messages": {"en": "Messages", "fr": "Messages"},
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,10 +17,11 @@ from modules.datamodels.datamodelChatbot import (
|
|||
MessageItem,
|
||||
ChatMessageResponse,
|
||||
ThreadSummary,
|
||||
ThreadDetail,
|
||||
)
|
||||
from modules.datamodels.datamodelUam import User
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -491,3 +492,126 @@ async def post_message_stream(
|
|||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
|
||||
|
||||
async def get_thread_messages_from_langgraph(
|
||||
*,
|
||||
thread_id: str,
|
||||
app,
|
||||
) -> List[dict]:
|
||||
"""Retrieve and format messages from LangGraph checkpointer.
|
||||
|
||||
Args:
|
||||
thread_id: The unique identifier for the chat thread.
|
||||
app: The compiled LangGraph app with checkpointer.
|
||||
|
||||
Returns:
|
||||
List of message dicts with role, content, and timestamp.
|
||||
"""
|
||||
ROLE_MAP = {"human": "user", "ai": "assistant"}
|
||||
|
||||
cfg = {"configurable": {"thread_id": thread_id}}
|
||||
state = await app.aget_state(cfg)
|
||||
|
||||
messages = []
|
||||
for msg in state.values.get("messages", []):
|
||||
# Skip system and tool messages - only include user and assistant
|
||||
if msg.type not in ["human", "ai"]:
|
||||
continue
|
||||
|
||||
# Convert content to string if needed
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": ROLE_MAP.get(msg.type, msg.type),
|
||||
"content": content,
|
||||
"timestamp": 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
async def get_thread_detail_for_user(
|
||||
*,
|
||||
thread_id: str,
|
||||
user: User,
|
||||
session: AsyncSession,
|
||||
) -> ThreadDetail:
|
||||
"""Get detailed thread information with message history from LangGraph.
|
||||
|
||||
Args:
|
||||
thread_id: The unique identifier for the chat thread.
|
||||
user: The current user.
|
||||
session: The database session for querying.
|
||||
|
||||
Returns:
|
||||
ThreadDetail object with thread metadata and message history.
|
||||
|
||||
Raises:
|
||||
PermissionError: If the thread does not belong to the user.
|
||||
ValueError: If the thread does not exist.
|
||||
"""
|
||||
logger.info(f"Getting thread detail for thread {thread_id} for user {user.id}")
|
||||
|
||||
# Verify thread exists and belongs to user
|
||||
await assure_thread_exists_and_belongs_to_user(
|
||||
thread_id=thread_id, user=user, session=session
|
||||
)
|
||||
|
||||
# Get thread metadata from database
|
||||
stmt = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id)
|
||||
result = await session.execute(stmt)
|
||||
thread_mapping = result.scalar_one()
|
||||
|
||||
# Build the chatbot app to access LangGraph state
|
||||
# Use same approach as post_message for consistency
|
||||
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
||||
if not tool_ids:
|
||||
raise ValueError("User does not have permission to use any chatbot tools")
|
||||
|
||||
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||
|
||||
# Get tools from registry
|
||||
registry = get_registry()
|
||||
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||
|
||||
# Get model and checkpointer
|
||||
model = get_langchain_model(model_name=model_name)
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
# Get context window size from config
|
||||
context_window_size = int(
|
||||
APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
|
||||
)
|
||||
|
||||
# Create chatbot instance
|
||||
chatbot = await Chatbot.create(
|
||||
model=model,
|
||||
memory=checkpointer,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
context_window_size=context_window_size,
|
||||
)
|
||||
|
||||
# Get messages from LangGraph checkpointer
|
||||
message_dicts = await get_thread_messages_from_langgraph(
|
||||
thread_id=thread_id, app=chatbot.app
|
||||
)
|
||||
|
||||
# Convert to MessageItem objects
|
||||
messages = [MessageItem(**m) for m in message_dicts]
|
||||
|
||||
logger.info(
|
||||
f"Retrieved thread {thread_id} with {len(messages)} messages for user {user.id}"
|
||||
)
|
||||
|
||||
# Return ThreadDetail
|
||||
return ThreadDetail(
|
||||
thread_id=thread_id,
|
||||
date_created=thread_mapping.date_created.timestamp(),
|
||||
date_updated=thread_mapping.date_updated.timestamp(),
|
||||
messages=messages,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -174,47 +174,37 @@ async def get_all_threads(
|
|||
@router.get("/threads/{thread_id}", response_model=ThreadDetail)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_thread_by_id(
|
||||
*, request: Request, thread_id: str, currentUser: User = Depends(getCurrentUser)
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ThreadDetail:
|
||||
"""
|
||||
Get a specific chat thread with all its messages.
|
||||
|
||||
This endpoint will later fetch from LangGraph's PostgreSQL checkpointer.
|
||||
Get a specific chat thread with all its messages from LangGraph checkpointer.
|
||||
"""
|
||||
try:
|
||||
# Return dummy thread detail
|
||||
# In production, this will query LangGraph's checkpointer for the specific thread
|
||||
current_time = datetime.now().timestamp()
|
||||
|
||||
dummy_messages = [
|
||||
MessageItem(
|
||||
role="user",
|
||||
content="Hello! I need help with Python.",
|
||||
timestamp=current_time - 120,
|
||||
),
|
||||
MessageItem(
|
||||
role="assistant",
|
||||
content="Hello! I'd be happy to help you with Python. What would you like to know?",
|
||||
timestamp=current_time - 119,
|
||||
),
|
||||
MessageItem(
|
||||
role="user",
|
||||
content="How do I use list comprehensions?",
|
||||
timestamp=current_time - 60,
|
||||
),
|
||||
MessageItem(
|
||||
role="assistant",
|
||||
content="List comprehensions are a concise way to create lists. Here's an example: [x**2 for x in range(10)]",
|
||||
timestamp=current_time - 59,
|
||||
),
|
||||
]
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
|
||||
|
||||
return ThreadDetail(
|
||||
thread_id=thread_id, created_at=current_time - 120, messages=dummy_messages
|
||||
thread_detail = await chat_service.get_thread_detail_for_user(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
|
||||
return thread_detail
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found",
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
|
|
|
|||
Loading…
Reference in a new issue