"""Service layer for chatbot functionality."""

import json
import logging
import uuid
from datetime import datetime, timezone
from typing import AsyncIterator, Iterable, List, Optional

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
from modules.features.chatBot.utils.checkpointer import get_checkpointer
from modules.features.chatBot.utils.toolRegistry import get_registry
from modules.features.chatBot.utils import permissions
from modules.features.chatBot.database import UserThreadMapping
from modules.datamodels.datamodelChatbot import (
    MessageItem,
    ChatMessageResponse,
    ThreadSummary,
    ThreadDetail,
)
from modules.datamodels.datamodelUam import User
from modules.shared.configuration import APP_CONFIG

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------


def _get_context_window_size() -> int:
    """Return the chatbot context window size (tokens) from the app config.

    Falls back to 100000 when CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE is not set.
    """
    return int(APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000))


async def _create_chatbot_for_user(*, user: User, tool_ids) -> Chatbot:
    """Build a Chatbot configured with the user's model, prompt and tools.

    The caller is responsible for checking that ``tool_ids`` is non-empty,
    because each public entry point reports that condition differently.

    Args:
        user: The current user.
        tool_ids: Tool identifiers as returned by
            ``permissions.get_chatbot_tools`` for this user.

    Returns:
        A Chatbot instance backed by the shared checkpointer.
    """
    model_name = permissions.get_chatbot_model(user_id=user.id)
    system_prompt = permissions.get_system_prompt(user_id=user.id)

    tools = get_registry().get_tool_instances(tool_ids=tool_ids)

    model = get_langchain_model(model_name=model_name)
    checkpointer = get_checkpointer()

    return await Chatbot.create(
        model=model,
        memory=checkpointer,
        system_prompt=system_prompt,
        tools=tools,
        context_window_size=_get_context_window_size(),
    )


def _serialize_chat_messages(messages: Iterable[BaseMessage]) -> List[dict]:
    """Convert LangChain messages into ``{"role", "content"}`` dicts for clients.

    Only human ("user") and AI ("assistant") messages are kept. Messages whose
    content is structured (non-str, e.g. tool-call blocks) or empty after
    stripping (e.g. AI turns that only issued tool calls) are skipped, so
    every endpoint renders the same conversation the same way.
    """
    serialized: List[dict] = []
    for msg in messages:
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        else:
            continue  # system / tool / other message types are not shown

        if not isinstance(msg.content, str):
            continue
        content = msg.content.strip()
        if not content:
            continue

        serialized.append({"role": role, "content": content})
    return serialized


def _sse(payload: dict) -> str:
    """Format ``payload`` as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _chat_response_from_final_event(
    *, thread_id: str, event: dict
) -> ChatMessageResponse:
    """Build the ChatMessageResponse carried by a stream ``final`` event.

    The event is expected to hold ``response.chat_history`` as a list of
    ``{"role", "content"}`` dicts. A non-list payload is logged and an empty
    history is returned. Entries that are not dicts, have another role, or
    have empty content are skipped.
    """
    response_from_event = event.get("response") or {}
    chat_history_payload = response_from_event.get("chat_history", [])

    if not isinstance(chat_history_payload, list):
        logger.warning(
            f"Unexpected chat_history format in final event: {type(chat_history_payload)}"
        )
        return ChatMessageResponse(thread_id=thread_id, messages=[])

    items: List[MessageItem] = []
    for entry in chat_history_payload:
        if not isinstance(entry, dict):
            continue
        role = entry.get("role")
        content = entry.get("content", "")
        if role in ("user", "assistant") and content:
            items.append(MessageItem(role=role, content=content))
    return ChatMessageResponse(thread_id=thread_id, messages=items)


# ---------------------------------------------------------------------------
# Thread bookkeeping
# ---------------------------------------------------------------------------


async def get_all_threads_for_user(
    *,
    user: User,
    session: AsyncSession,
) -> List[ThreadSummary]:
    """Get all chat threads for a user.

    Args:
        user: The current user.
        session: The database session for querying.

    Returns:
        List of ThreadSummary objects sorted by date_updated (newest first).
        Returns empty list if no threads found.
    """
    logger.info(f"Fetching all threads for user {user.id}")

    stmt = (
        select(UserThreadMapping)
        .where(UserThreadMapping.user_id == user.id)
        .order_by(UserThreadMapping.date_updated.desc())
    )
    result = await session.execute(stmt)

    threads = [
        ThreadSummary(
            thread_id=mapping.thread_id,
            thread_name=mapping.thread_name,
            date_created=mapping.date_created.timestamp(),
            date_updated=mapping.date_updated.timestamp(),
        )
        for mapping in result.scalars().all()
    ]

    logger.info(f"Found {len(threads)} threads for user {user.id}")
    return threads


async def save_thread_for_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
    thread_name: str = "New Chat",
) -> None:
    """Save a new chat thread mapping for the user and commit it.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: The database session for saving.
        thread_name: The name of the chat thread. Defaults to "New Chat".
    """
    logger.info(f"Saving new thread {thread_id} for user {user.id}")

    session.add(
        UserThreadMapping(
            user_id=user.id,
            thread_id=thread_id,
            thread_name=thread_name,
        )
    )
    await session.commit()

    logger.info(f"Successfully saved thread {thread_id} for user {user.id}")


async def get_or_create_thread_for_user(
    *,
    thread_id: Optional[str],
    user: User,
    session: AsyncSession,
    thread_name: str = "New Chat",
    refresh_date_updated: bool = False,
) -> str:
    """Get an existing thread or create a new one for the user.

    If thread_id is provided (truthy), verifies it exists and belongs to the
    user. Otherwise a new ``thread_<uuid4>`` id is generated and saved.

    Args:
        thread_id: Optional thread identifier. If None, creates a new thread.
        user: The current user.
        session: The database session for querying/saving.
        thread_name: The name for the thread if creating new.
            Defaults to "New Chat".
        refresh_date_updated: If True, refreshes date_updated for existing
            threads. Defaults to False.

    Returns:
        The thread_id to use (either the provided one or newly created).

    Raises:
        PermissionError: If the thread does not belong to the user.
        ValueError: If the provided thread_id does not exist.
    """
    if thread_id:
        await assure_thread_exists_and_belongs_to_user(
            thread_id=thread_id, user=user, session=session
        )
        logger.info(f"Using existing thread {thread_id} for user {user.id}")

        if refresh_date_updated:
            await refresh_thread_date_updated(
                thread_id=thread_id, user=user, session=session
            )
        return thread_id

    new_thread_id = f"thread_{uuid.uuid4()}"
    await save_thread_for_user(
        thread_id=new_thread_id,
        user=user,
        session=session,
        thread_name=thread_name,
    )
    logger.info(f"Created new thread {new_thread_id} for user {user.id}")
    return new_thread_id


async def assure_thread_exists_and_belongs_to_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> UserThreadMapping:
    """Ensure that the given thread ID exists and belongs to the specified user.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: The database session for querying.

    Returns:
        The verified UserThreadMapping row, so callers that also need the
        thread metadata do not have to query it a second time.

    Raises:
        PermissionError: If the thread does not belong to the user.
        ValueError: If the thread does not exist.
    """
    stmt = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id)
    result = await session.execute(stmt)
    thread_mapping = result.scalar_one_or_none()

    if thread_mapping is None:
        logger.warning(f"Thread {thread_id} does not exist")
        raise ValueError(f"Thread {thread_id} does not exist")

    if thread_mapping.user_id != user.id:
        logger.warning(
            f"User {user.id} attempted to access thread {thread_id} "
            f"belonging to user {thread_mapping.user_id}"
        )
        raise PermissionError(
            f"You do not have permission to access thread {thread_id}"
        )

    logger.info(f"Thread {thread_id} verified for user {user.id}")
    return thread_mapping


async def update_thread_name(
    *,
    thread_id: str,
    user: User,
    new_thread_name: str,
    session: AsyncSession,
) -> None:
    """Update the name (and date_updated) of an existing chat thread.

    The UPDATE's WHERE clause includes both thread_id and user_id, so a user
    can only modify their own threads; no separate permission check is needed.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        new_thread_name: The new name to set for the thread.
        session: The database session for updating.

    Raises:
        ValueError: If the thread does not exist or does not belong to the user.
    """
    logger.info(
        f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}"
    )

    stmt = (
        update(UserThreadMapping)
        .where(
            UserThreadMapping.thread_id == thread_id,
            UserThreadMapping.user_id == user.id,
        )
        .values(thread_name=new_thread_name, date_updated=datetime.now(timezone.utc))
    )
    result = await session.execute(stmt)
    await session.commit()

    # Zero affected rows means the thread is missing or owned by someone else.
    if result.rowcount == 0:
        logger.warning(
            f"Failed to update thread {thread_id} for user {user.id} - "
            "thread does not exist or does not belong to user"
        )
        raise ValueError(
            f"Thread {thread_id} does not exist or you do not have permission to access it"
        )

    logger.info(f"Successfully updated thread {thread_id} name for user {user.id}")


async def refresh_thread_date_updated(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> None:
    """Set date_updated of an existing chat thread to the current UTC time.

    The UPDATE's WHERE clause includes both thread_id and user_id, so a user
    can only modify their own threads; no separate permission check is needed.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: The database session for updating.

    Raises:
        ValueError: If the thread does not exist or does not belong to the user.
    """
    logger.info(f"Refreshing date_updated for thread {thread_id} for user {user.id}")

    stmt = (
        update(UserThreadMapping)
        .where(
            UserThreadMapping.thread_id == thread_id,
            UserThreadMapping.user_id == user.id,
        )
        .values(date_updated=datetime.now(timezone.utc))
    )
    result = await session.execute(stmt)
    await session.commit()

    # Zero affected rows means the thread is missing or owned by someone else.
    if result.rowcount == 0:
        logger.warning(
            f"Failed to refresh thread {thread_id} for user {user.id} - "
            "thread does not exist or does not belong to user"
        )
        raise ValueError(
            f"Thread {thread_id} does not exist or you do not have permission to access it"
        )

    logger.info(
        f"Successfully refreshed date_updated for thread {thread_id} for user {user.id}"
    )


# ---------------------------------------------------------------------------
# Chatting
# ---------------------------------------------------------------------------


async def post_message(
    *,
    thread_id: str,
    message: str,
    user: User,
) -> ChatMessageResponse:
    """Post a chat message to the chatbot and return the response.

    Args:
        thread_id: The unique identifier for the chat thread.
        message: The content of the chat message.
        user: The current user.

    Returns:
        The response containing the user/assistant chat history and thread ID.
        Tool-call-only and structured-content messages are omitted.

    Raises:
        ValueError: If the user has no permitted chatbot tools.
    """
    logger.info(f"User {user.id} posted message to thread {thread_id}")

    tool_ids = permissions.get_chatbot_tools(user_id=user.id)
    if not tool_ids:
        raise ValueError("User does not have permission to use any chatbot tools")

    chatbot = await _create_chatbot_for_user(user=user, tool_ids=tool_ids)
    response = await chatbot.chat(message=message, chat_id=thread_id)

    messages = [MessageItem(**item) for item in _serialize_chat_messages(response)]
    return ChatMessageResponse(thread_id=thread_id, messages=messages)


async def post_message_stream(
    *,
    thread_id: str,
    message: str,
    user: User,
) -> AsyncIterator[str]:
    """Post a chat message to the chatbot and stream progress updates (SSE).

    Emits ``status`` events while the chatbot works, then exactly one terminal
    event: ``final`` (with the chat history) or ``error``. If the underlying
    stream ends without either, an ``error`` event is emitted so the client
    is not left without a result.

    Args:
        thread_id: The unique identifier for the chat thread.
        message: The content of the chat message.
        user: The current user.

    Yields:
        Server-Sent Events formatted strings containing status updates and
        the final response.
    """
    logger.info(f"User {user.id} streaming message to thread {thread_id}")

    try:
        tool_ids = permissions.get_chatbot_tools(user_id=user.id)
        if not tool_ids:
            yield _sse(
                {
                    "type": "error",
                    "message": "User does not have permission to use any chatbot tools",
                }
            )
            return

        chatbot = await _create_chatbot_for_user(user=user, tool_ids=tool_ids)

        async for event in chatbot.stream_events(message=message, chat_id=thread_id):
            etype = event.get("type")

            if etype == "status":
                yield _sse({"type": "status", "label": event.get("label")})
            elif etype == "final":
                # chat_history is already normalized by stream_events.
                response = _chat_response_from_final_event(
                    thread_id=thread_id, event=event
                )
                yield _sse({"type": "final", "response": response.model_dump()})
                return
            elif etype == "error":
                yield _sse(event)
                return

        # The event stream was exhausted without a terminal event.
        logger.warning(
            f"Chat stream for thread {thread_id} ended without a final response"
        )
        yield _sse(
            {
                "type": "error",
                "message": "The chat stream ended without a final response",
            }
        )
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e) or 'No error message provided'}"
        logger.error(f"Error in streaming chat: {error_msg}", exc_info=True)
        # NOTE(review): the raw exception text is forwarded to the client;
        # consider whether internal error details should be exposed.
        yield _sse(
            {
                "type": "error",
                "message": f"An error occurred while processing your request: {error_msg}",
            }
        )


# ---------------------------------------------------------------------------
# Thread history
# ---------------------------------------------------------------------------


async def get_thread_messages_from_langgraph(
    *,
    thread_id: str,
    app,
) -> List[dict]:
    """Retrieve and format messages from the LangGraph checkpointer.

    Args:
        thread_id: The unique identifier for the chat thread.
        app: The compiled LangGraph app with checkpointer.

    Returns:
        List of dicts with "role" ("user" or "assistant") and "content" keys.
        System/tool messages, structured-content messages and empty messages
        are omitted, matching what post_message returns.
    """
    cfg = {"configurable": {"thread_id": thread_id}}
    state = await app.aget_state(cfg)
    return _serialize_chat_messages(state.values.get("messages", []))


async def get_thread_detail_for_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> ThreadDetail:
    """Get detailed thread information with message history from LangGraph.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: The database session for querying.

    Returns:
        ThreadDetail object with thread metadata and message history.

    Raises:
        PermissionError: If the thread does not belong to the user.
        ValueError: If the thread does not exist, or the user has no
            permitted chatbot tools.
    """
    logger.info(f"Getting thread detail for thread {thread_id} for user {user.id}")

    # Verification also loads the mapping, so no second query is needed.
    thread_mapping = await assure_thread_exists_and_belongs_to_user(
        thread_id=thread_id, user=user, session=session
    )

    # The chatbot app is built the same way as in post_message to access the
    # LangGraph state. NOTE(review): this means reading history requires tool
    # permissions — confirm that is intended.
    tool_ids = permissions.get_chatbot_tools(user_id=user.id)
    if not tool_ids:
        raise ValueError("User does not have permission to use any chatbot tools")

    chatbot = await _create_chatbot_for_user(user=user, tool_ids=tool_ids)

    message_dicts = await get_thread_messages_from_langgraph(
        thread_id=thread_id, app=chatbot.app
    )
    messages = [MessageItem(**m) for m in message_dicts]

    logger.info(
        f"Retrieved thread {thread_id} with {len(messages)} messages for user {user.id}"
    )

    return ThreadDetail(
        thread_id=thread_id,
        date_created=thread_mapping.date_created.timestamp(),
        date_updated=thread_mapping.date_updated.timestamp(),
        messages=messages,
    )