feat: implement get thread details endpoint
This commit is contained in:
parent
ed3920f9f9
commit
85503fc669
3 changed files with 154 additions and 38 deletions
|
|
@ -49,7 +49,8 @@ class ThreadDetail(BaseModel, ModelMixin):
|
||||||
"""Detailed view of a single thread"""
|
"""Detailed view of a single thread"""
|
||||||
|
|
||||||
thread_id: str = Field(..., description="Thread ID")
|
thread_id: str = Field(..., description="Thread ID")
|
||||||
created_at: float = Field(..., description="Thread creation timestamp")
|
date_created: float = Field(..., description="Thread creation timestamp")
|
||||||
|
date_updated: float = Field(..., description="Thread last updated timestamp")
|
||||||
messages: List[MessageItem] = Field(
|
messages: List[MessageItem] = Field(
|
||||||
..., description="All messages in chronological order"
|
..., description="All messages in chronological order"
|
||||||
)
|
)
|
||||||
|
|
@ -115,7 +116,8 @@ register_model_labels(
|
||||||
{"en": "Thread Detail", "fr": "Détail du fil"},
|
{"en": "Thread Detail", "fr": "Détail du fil"},
|
||||||
{
|
{
|
||||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||||
"created_at": {"en": "Created At", "fr": "Créé le"},
|
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
||||||
|
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
||||||
"messages": {"en": "Messages", "fr": "Messages"},
|
"messages": {"en": "Messages", "fr": "Messages"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,11 @@ from modules.datamodels.datamodelChatbot import (
|
||||||
MessageItem,
|
MessageItem,
|
||||||
ChatMessageResponse,
|
ChatMessageResponse,
|
||||||
ThreadSummary,
|
ThreadSummary,
|
||||||
|
ThreadDetail,
|
||||||
)
|
)
|
||||||
from modules.datamodels.datamodelUam import User
|
from modules.datamodels.datamodelUam import User
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage
|
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
||||||
from modules.shared.configuration import APP_CONFIG
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -491,3 +492,126 @@ async def post_message_stream(
|
||||||
)
|
)
|
||||||
+ "\n\n"
|
+ "\n\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_thread_messages_from_langgraph(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
app,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""Retrieve and format messages from LangGraph checkpointer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: The unique identifier for the chat thread.
|
||||||
|
app: The compiled LangGraph app with checkpointer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of message dicts with role, content, and timestamp.
|
||||||
|
"""
|
||||||
|
ROLE_MAP = {"human": "user", "ai": "assistant"}
|
||||||
|
|
||||||
|
cfg = {"configurable": {"thread_id": thread_id}}
|
||||||
|
state = await app.aget_state(cfg)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
for msg in state.values.get("messages", []):
|
||||||
|
# Skip system and tool messages - only include user and assistant
|
||||||
|
if msg.type not in ["human", "ai"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert content to string if needed
|
||||||
|
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": ROLE_MAP.get(msg.type, msg.type),
|
||||||
|
"content": content,
|
||||||
|
"timestamp": 0.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
async def get_thread_detail_for_user(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
user: User,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> ThreadDetail:
|
||||||
|
"""Get detailed thread information with message history from LangGraph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: The unique identifier for the chat thread.
|
||||||
|
user: The current user.
|
||||||
|
session: The database session for querying.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ThreadDetail object with thread metadata and message history.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PermissionError: If the thread does not belong to the user.
|
||||||
|
ValueError: If the thread does not exist.
|
||||||
|
"""
|
||||||
|
logger.info(f"Getting thread detail for thread {thread_id} for user {user.id}")
|
||||||
|
|
||||||
|
# Verify thread exists and belongs to user
|
||||||
|
await assure_thread_exists_and_belongs_to_user(
|
||||||
|
thread_id=thread_id, user=user, session=session
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get thread metadata from database
|
||||||
|
stmt = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
thread_mapping = result.scalar_one()
|
||||||
|
|
||||||
|
# Build the chatbot app to access LangGraph state
|
||||||
|
# Use same approach as post_message for consistency
|
||||||
|
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
||||||
|
if not tool_ids:
|
||||||
|
raise ValueError("User does not have permission to use any chatbot tools")
|
||||||
|
|
||||||
|
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||||
|
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||||
|
|
||||||
|
# Get tools from registry
|
||||||
|
registry = get_registry()
|
||||||
|
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||||
|
|
||||||
|
# Get model and checkpointer
|
||||||
|
model = get_langchain_model(model_name=model_name)
|
||||||
|
checkpointer = get_checkpointer()
|
||||||
|
|
||||||
|
# Get context window size from config
|
||||||
|
context_window_size = int(
|
||||||
|
APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create chatbot instance
|
||||||
|
chatbot = await Chatbot.create(
|
||||||
|
model=model,
|
||||||
|
memory=checkpointer,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
context_window_size=context_window_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get messages from LangGraph checkpointer
|
||||||
|
message_dicts = await get_thread_messages_from_langgraph(
|
||||||
|
thread_id=thread_id, app=chatbot.app
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to MessageItem objects
|
||||||
|
messages = [MessageItem(**m) for m in message_dicts]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Retrieved thread {thread_id} with {len(messages)} messages for user {user.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return ThreadDetail
|
||||||
|
return ThreadDetail(
|
||||||
|
thread_id=thread_id,
|
||||||
|
date_created=thread_mapping.date_created.timestamp(),
|
||||||
|
date_updated=thread_mapping.date_updated.timestamp(),
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -174,47 +174,37 @@ async def get_all_threads(
|
||||||
@router.get("/threads/{thread_id}", response_model=ThreadDetail)
|
@router.get("/threads/{thread_id}", response_model=ThreadDetail)
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
async def get_thread_by_id(
|
async def get_thread_by_id(
|
||||||
*, request: Request, thread_id: str, currentUser: User = Depends(getCurrentUser)
|
*,
|
||||||
|
request: Request,
|
||||||
|
thread_id: str,
|
||||||
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
session: AsyncSession = Depends(get_async_db_session),
|
||||||
) -> ThreadDetail:
|
) -> ThreadDetail:
|
||||||
"""
|
"""
|
||||||
Get a specific chat thread with all its messages.
|
Get a specific chat thread with all its messages from LangGraph checkpointer.
|
||||||
|
|
||||||
This endpoint will later fetch from LangGraph's PostgreSQL checkpointer.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Return dummy thread detail
|
thread_detail = await chat_service.get_thread_detail_for_user(
|
||||||
# In production, this will query LangGraph's checkpointer for the specific thread
|
thread_id=thread_id,
|
||||||
current_time = datetime.now().timestamp()
|
user=currentUser,
|
||||||
|
session=session,
|
||||||
dummy_messages = [
|
|
||||||
MessageItem(
|
|
||||||
role="user",
|
|
||||||
content="Hello! I need help with Python.",
|
|
||||||
timestamp=current_time - 120,
|
|
||||||
),
|
|
||||||
MessageItem(
|
|
||||||
role="assistant",
|
|
||||||
content="Hello! I'd be happy to help you with Python. What would you like to know?",
|
|
||||||
timestamp=current_time - 119,
|
|
||||||
),
|
|
||||||
MessageItem(
|
|
||||||
role="user",
|
|
||||||
content="How do I use list comprehensions?",
|
|
||||||
timestamp=current_time - 60,
|
|
||||||
),
|
|
||||||
MessageItem(
|
|
||||||
role="assistant",
|
|
||||||
content="List comprehensions are a concise way to create lists. Here's an example: [x**2 for x in range(10)]",
|
|
||||||
timestamp=current_time - 59,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
|
|
||||||
|
|
||||||
return ThreadDetail(
|
|
||||||
thread_id=thread_id, created_at=current_time - 120, messages=dummy_messages
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
|
||||||
|
return thread_detail
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Thread not found: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=str(e) or "Thread not found",
|
||||||
|
)
|
||||||
|
except PermissionError as e:
|
||||||
|
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=str(e) or "Permission denied",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}",
|
f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue