gateway/modules/features/chatBot/service.py
2025-10-08 16:50:54 +02:00

702 lines
23 KiB
Python

"""Service layer for chatbot functionality."""
import json
import logging
from datetime import datetime, timezone
from typing import AsyncIterator, List, Optional
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
from modules.features.chatBot.utils.checkpointer import get_checkpointer
from modules.features.chatBot.utils.toolRegistry import get_registry
from modules.features.chatBot.utils import permissions
from modules.features.chatBot.database import UserThreadMapping
from modules.datamodels.datamodelChatbot import (
MessageItem,
ChatMessageResponse,
ThreadSummary,
ThreadDetail,
)
from modules.datamodels.datamodelUam import User
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from modules.shared.configuration import APP_CONFIG
# Module-level logger named after this module's import path (``__name__``).
logger = logging.getLogger(__name__)
async def get_all_threads_for_user(
    *,
    user: User,
    session: AsyncSession,
) -> List[ThreadSummary]:
    """Return every chat thread owned by ``user``.

    Args:
        user: The user whose threads should be listed.
        session: Async database session used for the lookup.

    Returns:
        ThreadSummary objects ordered by ``date_updated``, most recently
        updated first. The list is empty when the user owns no threads.
    """
    logger.info(f"Fetching all threads for user {user.id}")

    # Only this user's mappings, newest activity first.
    query = (
        select(UserThreadMapping)
        .where(UserThreadMapping.user_id == user.id)
        .order_by(UserThreadMapping.date_updated.desc())
    )
    rows = (await session.execute(query)).scalars().all()

    # Datetimes are exposed to callers as POSIX timestamps (floats).
    threads = [
        ThreadSummary(
            thread_id=row.thread_id,
            thread_name=row.thread_name,
            date_created=row.date_created.timestamp(),
            date_updated=row.date_updated.timestamp(),
        )
        for row in rows
    ]

    logger.info(f"Found {len(threads)} threads for user {user.id}")
    return threads
async def save_thread_for_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
    thread_name: str = "New Chat",
) -> None:
    """Persist a new user-to-thread mapping and commit it.

    Args:
        thread_id: Identifier of the chat thread to register.
        user: Owner of the new thread.
        session: Async database session; it is committed before returning.
        thread_name: Display name stored for the thread. Defaults to "New Chat".
    """
    logger.info(f"Saving new thread {thread_id} for user {user.id}")

    session.add(
        UserThreadMapping(
            user_id=user.id,
            thread_id=thread_id,
            thread_name=thread_name,
        )
    )
    await session.commit()

    logger.info(f"Successfully saved thread {thread_id} for user {user.id}")
async def get_or_create_thread_for_user(
    *,
    thread_id: Optional[str],
    user: User,
    session: AsyncSession,
    thread_name: str = "New Chat",
    refresh_date_updated: bool = False,
) -> str:
    """Resolve the thread to use for a request, creating a new one if needed.

    A truthy ``thread_id`` is validated (it must exist and belong to ``user``)
    and returned unchanged. A falsy ``thread_id`` (``None`` or an empty string)
    causes a fresh ``thread_<uuid4>`` identifier to be generated and saved.

    Args:
        thread_id: Existing thread identifier, or None to create a new thread.
        user: The current user.
        session: The database session for querying/saving.
        thread_name: Name stored when a new thread is created. Defaults to "New Chat".
        refresh_date_updated: When True and an existing thread is used, bump its
            ``date_updated`` timestamp. Defaults to False.

    Returns:
        The thread_id to use (either the provided one or the newly created one).

    Raises:
        PermissionError: If the provided thread belongs to another user.
        ValueError: If the provided thread does not exist, or the timestamp
            refresh matches no row.
    """
    if not thread_id:
        # No id supplied: mint and persist a brand-new thread.
        import uuid

        created_id = f"thread_{uuid.uuid4()}"
        await save_thread_for_user(
            thread_id=created_id,
            user=user,
            session=session,
            thread_name=thread_name,
        )
        logger.info(f"Created new thread {created_id} for user {user.id}")
        return created_id

    # Caller supplied an id: it must exist and be owned by this user.
    await assure_thread_exists_and_belongs_to_user(
        thread_id=thread_id, user=user, session=session
    )
    logger.info(f"Using existing thread {thread_id} for user {user.id}")

    if refresh_date_updated:
        await refresh_thread_date_updated(
            thread_id=thread_id, user=user, session=session
        )
    return thread_id
async def assure_thread_exists_and_belongs_to_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> None:
    """Verify that ``thread_id`` exists and is owned by ``user``.

    Returns normally when both checks pass; otherwise raises.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: The database session for querying.

    Raises:
        ValueError: If no mapping row exists for ``thread_id``.
        PermissionError: If the mapping row belongs to a different user.
    """
    # Lookup is by thread_id alone; scalar_one_or_none() raises if more than
    # one row matches (presumably thread_id is unique — verify in the schema).
    lookup = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id)
    mapping = (await session.execute(lookup)).scalar_one_or_none()

    if mapping is None:
        logger.warning(f"Thread {thread_id} does not exist")
        raise ValueError(f"Thread {thread_id} does not exist")

    owner_id = mapping.user_id
    if owner_id != user.id:
        logger.warning(
            f"User {user.id} attempted to access thread {thread_id} "
            f"belonging to user {owner_id}"
        )
        raise PermissionError(
            f"You do not have permission to access thread {thread_id}"
        )

    logger.info(f"Thread {thread_id} verified for user {user.id}")
async def update_thread_name(
    *,
    thread_id: str,
    user: User,
    new_thread_name: str,
    session: AsyncSession,
) -> None:
    """Rename a thread owned by ``user`` and bump its ``date_updated``.

    Ownership is enforced by the UPDATE itself: its WHERE clause matches on
    both thread_id and user_id, so another user's thread is simply not matched
    and no separate permission lookup is needed.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        new_thread_name: The new name to set for the thread.
        session: The database session for updating; committed before checking.

    Raises:
        ValueError: If no row matched (thread missing or owned by someone else).
    """
    logger.info(
        f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}"
    )

    rename = (
        update(UserThreadMapping)
        .where(
            UserThreadMapping.thread_id == thread_id,
            UserThreadMapping.user_id == user.id,
        )
        .values(thread_name=new_thread_name, date_updated=datetime.now(timezone.utc))
    )
    outcome = await session.execute(rename)
    await session.commit()

    if outcome.rowcount != 0:
        logger.info(f"Successfully updated thread {thread_id} name for user {user.id}")
        return

    # Zero affected rows: either the thread does not exist or it is not ours.
    logger.warning(
        f"Failed to update thread {thread_id} for user {user.id} - "
        "thread does not exist or does not belong to user"
    )
    raise ValueError(
        f"Thread {thread_id} does not exist or you do not have permission to access it"
    )
async def refresh_thread_date_updated(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> None:
    """Set ``date_updated`` to the current UTC time for a user's thread.

    The UPDATE matches on both thread_id and user_id, which doubles as the
    ownership check: other users' threads are never touched and no extra
    permission query is issued.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: The database session for updating; committed before checking.

    Raises:
        ValueError: If no row matched (thread missing or owned by someone else).
    """
    logger.info(f"Refreshing date_updated for thread {thread_id} for user {user.id}")

    owned_by_user = (
        UserThreadMapping.thread_id == thread_id,
        UserThreadMapping.user_id == user.id,
    )
    touch = (
        update(UserThreadMapping)
        .where(*owned_by_user)
        .values(date_updated=datetime.now(timezone.utc))
    )
    outcome = await session.execute(touch)
    await session.commit()

    if outcome.rowcount == 0:
        logger.warning(
            f"Failed to refresh thread {thread_id} for user {user.id} - "
            "thread does not exist or does not belong to user"
        )
        raise ValueError(
            f"Thread {thread_id} does not exist or you do not have permission to access it"
        )

    logger.info(
        f"Successfully refreshed date_updated for thread {thread_id} for user {user.id}"
    )
async def post_message(
    *,
    thread_id: str,
    message: str,
    user: User,
) -> ChatMessageResponse:
    """Post a chat message to the chatbot and return the resulting history.

    The chatbot is built from the user's permitted tools, model and system
    prompt. It uses the checkpointer from ``get_checkpointer()`` as its memory,
    with ``thread_id`` as the chat id.

    Args:
        thread_id: The unique identifier for the chat thread.
        message: The content of the chat message.
        user: The current user.

    Returns:
        ChatMessageResponse holding ``thread_id`` and the conversation's user
        and assistant messages, with their text stripped. Other message types,
        structured (non-string) content, and messages with no text are omitted.

    Raises:
        ValueError: If the user has no permitted chatbot tools.
    """
    logger.info(f"User {user.id} posted message to thread {thread_id}")

    # Resolve what this user is allowed to use
    tool_ids = permissions.get_chatbot_tools(user_id=user.id)
    if not tool_ids:
        raise ValueError("User does not have permission to use any chatbot tools")
    model_name = permissions.get_chatbot_model(user_id=user.id)
    system_prompt = permissions.get_system_prompt(user_id=user.id)

    # Get tools from registry
    registry = get_registry()
    tools = registry.get_tool_instances(tool_ids=tool_ids)

    # Get model and checkpointer
    model = get_langchain_model(model_name=model_name)
    checkpointer = get_checkpointer()

    # Get context window size from config
    context_window_size = int(
        APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
    )

    chatbot = await Chatbot.create(
        model=model,
        memory=checkpointer,
        system_prompt=system_prompt,
        tools=tools,
        context_window_size=context_window_size,
    )

    response = await chatbot.chat(message=message, chat_id=thread_id)

    messages: List[MessageItem] = []
    for msg in response:
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        else:
            continue  # any other message type (e.g. system / tool) is not shown

        # Structured (non-string) content, such as tool-call blocks, is skipped
        if not isinstance(msg.content, str):
            continue

        content = msg.content.strip()
        # AI turns that only issue tool calls can carry empty text; skip them so
        # no blank assistant messages are returned (post_message_stream likewise
        # drops history entries with empty content).
        if not content:
            continue

        messages.append(MessageItem(role=role, content=content))

    return ChatMessageResponse(thread_id=thread_id, messages=messages)
def _format_sse_event(payload: dict) -> str:
    """Serialize ``payload`` as one Server-Sent Events ``data:`` frame.

    ``json.dumps`` (without ``indent``) escapes newlines inside strings, so the
    payload always fits on a single ``data:`` line terminated by a blank line.
    """
    return f"data: {json.dumps(payload)}\n\n"


async def post_message_stream(
    *,
    thread_id: str,
    message: str,
    user: User,
) -> AsyncIterator[str]:
    """Post a chat message to the chatbot and stream progress updates (SSE).

    Event types yielded (each as a ``data: <json>\\n\\n`` frame):
      * ``status`` - progress label forwarded from the chatbot.
      * ``final`` - the user/assistant chat history as a ChatMessageResponse.
        Malformed history entries, non-string content and empty messages are
        skipped, and content is stripped.
      * ``error`` - forwarded chatbot errors, missing permissions, unexpected
        exceptions, or a stream that ended without a final/error event.

    The generator stops after the first ``final`` or ``error`` frame.

    Args:
        thread_id: The unique identifier for the chat thread.
        message: The content of the chat message.
        user: The current user.

    Yields:
        Server-Sent Events formatted strings containing status updates and
        exactly one terminal (``final`` or ``error``) event.
    """
    logger.info(f"User {user.id} streaming message to thread {thread_id}")
    try:
        # Get user permissions
        tool_ids = permissions.get_chatbot_tools(user_id=user.id)
        if not tool_ids:
            yield _format_sse_event(
                {
                    "type": "error",
                    "message": "User does not have permission to use any chatbot tools",
                }
            )
            return
        model_name = permissions.get_chatbot_model(user_id=user.id)
        system_prompt = permissions.get_system_prompt(user_id=user.id)

        # Get tools from registry
        registry = get_registry()
        tools = registry.get_tool_instances(tool_ids=tool_ids)

        # Get model and checkpointer
        model = get_langchain_model(model_name=model_name)
        checkpointer = get_checkpointer()

        # Get context window size from config
        context_window_size = int(
            APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
        )

        chatbot = await Chatbot.create(
            model=model,
            memory=checkpointer,
            system_prompt=system_prompt,
            tools=tools,
            context_window_size=context_window_size,
        )

        async for event in chatbot.stream_events(message=message, chat_id=thread_id):
            etype = event.get("type")

            if etype == "status":
                yield _format_sse_event({"type": "status", "label": event.get("label")})
                continue

            if etype == "final":
                response_from_event = event.get("response") or {}
                # Chat history from the final event (already normalized by stream_events)
                chat_history_payload = response_from_event.get("chat_history", [])

                items: List[MessageItem] = []
                if isinstance(chat_history_payload, list):
                    for entry in chat_history_payload:
                        # Skip malformed entries rather than letting one bad item
                        # abort the whole response via the except below.
                        if not isinstance(entry, dict):
                            continue
                        role = entry.get("role")
                        content = entry.get("content", "")
                        if role not in ("user", "assistant") or not isinstance(content, str):
                            continue
                        content = content.strip()
                        if content:
                            items.append(MessageItem(role=role, content=content))
                else:
                    # Unexpected payload format - log and return an empty history
                    logger.warning(
                        f"Unexpected chat_history format in final event: {type(chat_history_payload)}"
                    )

                response = ChatMessageResponse(thread_id=thread_id, messages=items)
                yield _format_sse_event({"type": "final", "response": response.model_dump()})
                return

            if etype == "error":
                yield _format_sse_event(event)
                return

        # The chatbot stream ran out without a terminal event; tell the client
        # explicitly instead of closing the stream silently.
        logger.warning(
            f"Chat stream for thread {thread_id} ended without a final event"
        )
        yield _format_sse_event(
            {
                "type": "error",
                "message": "The chatbot finished without producing a response",
            }
        )
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e) or 'No error message provided'}"
        logger.error(f"Error in streaming chat: {error_msg}", exc_info=True)
        # NOTE(review): the raw exception text is sent to the client; confirm
        # this cannot expose internal details (hosts, credentials, SQL).
        yield _format_sse_event(
            {
                "type": "error",
                "message": f"An error occurred while processing your request: {error_msg}",
            }
        )
def _content_to_text(content) -> str:
    """Extract displayable text from a LangChain message ``content`` value.

    Assumes the LangChain convention that ``content`` is either a plain string
    or a list of content blocks, where each block is a string or a dict. Only
    string blocks and dict blocks with ``type == "text"`` (via their ``"text"``
    key) contribute, so tool-call blocks are dropped instead of being rendered
    as a Python repr. Any other value falls back to ``str(content)``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(str(block.get("text", "")))
        return "".join(parts)
    return str(content)


async def get_thread_messages_from_langgraph(
    *,
    thread_id: str,
    app,
) -> List[dict]:
    """Retrieve and normalize a thread's messages from the LangGraph checkpointer.

    Only human and AI messages are returned; every other message type (system,
    tool, ...) is dropped. Content blocks are reduced to their text. Messages
    left with no visible text (e.g. AI turns that only issued tool calls) are
    skipped, mirroring the empty-content filter in post_message_stream.

    Args:
        thread_id: The unique identifier for the chat thread.
        app: The compiled LangGraph app (must expose ``aget_state``).

    Returns:
        List of ``{"role": "user" | "assistant", "content": str}`` dicts in the
        order stored in the checkpoint state. Empty when the thread has no
        ``messages`` in its state.
    """
    role_map = {"human": "user", "ai": "assistant"}
    cfg = {"configurable": {"thread_id": thread_id}}
    state = await app.aget_state(cfg)

    messages: List[dict] = []
    for msg in state.values.get("messages", []):
        role = role_map.get(msg.type)
        if role is None:
            continue  # not a user/assistant message

        content = _content_to_text(msg.content)
        if not content.strip():
            continue

        messages.append({"role": role, "content": content})
    return messages
async def get_thread_detail_for_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> ThreadDetail:
    """Return a thread's metadata together with its LangGraph message history.

    Ownership is checked first. The chatbot app is then rebuilt with the user's
    tool/model/prompt settings so its checkpointed state can be read.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: The database session for querying.

    Returns:
        ThreadDetail with thread_id, creation/update timestamps (POSIX floats)
        and the user/assistant message history.

    Raises:
        PermissionError: If the thread does not belong to the user.
        ValueError: If the thread does not exist, or the user has no permitted
            chatbot tools.
    """
    logger.info(f"Getting thread detail for thread {thread_id} for user {user.id}")

    await assure_thread_exists_and_belongs_to_user(
        thread_id=thread_id, user=user, session=session
    )

    # Existence was just verified; scalar_one() raises if the row is gone.
    mapping_query = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id)
    mapping = (await session.execute(mapping_query)).scalar_one()

    # Rebuild the chatbot with the same inputs post_message uses, to reach its app state.
    allowed_tools = permissions.get_chatbot_tools(user_id=user.id)
    if not allowed_tools:
        raise ValueError("User does not have permission to use any chatbot tools")
    chosen_model = permissions.get_chatbot_model(user_id=user.id)
    prompt = permissions.get_system_prompt(user_id=user.id)
    tool_instances = get_registry().get_tool_instances(tool_ids=allowed_tools)
    llm = get_langchain_model(model_name=chosen_model)
    memory = get_checkpointer()
    window = int(APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000))

    chatbot = await Chatbot.create(
        model=llm,
        memory=memory,
        system_prompt=prompt,
        tools=tool_instances,
        context_window_size=window,
    )

    history = await get_thread_messages_from_langgraph(
        thread_id=thread_id, app=chatbot.app
    )
    messages = [MessageItem(**entry) for entry in history]

    logger.info(
        f"Retrieved thread {thread_id} with {len(messages)} messages for user {user.id}"
    )

    return ThreadDetail(
        thread_id=thread_id,
        date_created=mapping.date_created.timestamp(),
        date_updated=mapping.date_updated.timestamp(),
        messages=messages,
    )
async def delete_thread_for_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> None:
    """Delete a chat thread for a user from both LangGraph and the database.

    The checkpointed conversation is removed first, then the user/thread
    mapping row. The two deletions are not coordinated in one transaction
    here: if the database step fails, the LangGraph history is already gone.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: The database session for deleting.

    Raises:
        PermissionError: If the thread does not belong to the user.
        ValueError: If the thread does not exist, the user has no permitted
            chatbot tools, the LangGraph deletion fails (the original error is
            chained as ``__cause__``), or no mapping row was deleted.
    """
    logger.info(f"Deleting thread {thread_id} for user {user.id}")

    # Verify thread exists and belongs to user
    await assure_thread_exists_and_belongs_to_user(
        thread_id=thread_id, user=user, session=session
    )

    # Build the chatbot app to access the checkpointer
    tool_ids = permissions.get_chatbot_tools(user_id=user.id)
    if not tool_ids:
        raise ValueError("User does not have permission to use any chatbot tools")
    model_name = permissions.get_chatbot_model(user_id=user.id)
    system_prompt = permissions.get_system_prompt(user_id=user.id)

    registry = get_registry()
    tools = registry.get_tool_instances(tool_ids=tool_ids)

    model = get_langchain_model(model_name=model_name)
    checkpointer = get_checkpointer()

    context_window_size = int(
        APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
    )

    chatbot = await Chatbot.create(
        model=model,
        memory=checkpointer,
        system_prompt=system_prompt,
        tools=tools,
        context_window_size=context_window_size,
    )

    # Delete from LangGraph checkpointer; keep the try body to the one call that can fail.
    try:
        await chatbot.app.checkpointer.adelete_thread(thread_id)
    except Exception as e:
        logger.exception(
            f"Failed to delete thread {thread_id} from LangGraph: {type(e).__name__}: {str(e)}"
        )
        # Chain the cause so the original traceback is not lost.
        raise ValueError(
            f"Failed to delete thread from LangGraph: {type(e).__name__}: {str(e)}"
        ) from e
    logger.info(f"Deleted thread {thread_id} from LangGraph checkpointer")

    # Delete from database; user_id in the WHERE clause re-asserts ownership
    stmt = delete(UserThreadMapping).where(
        UserThreadMapping.thread_id == thread_id,
        UserThreadMapping.user_id == user.id,
    )
    result = await session.execute(stmt)
    await session.commit()

    if result.rowcount == 0:
        logger.warning(
            f"Failed to delete thread {thread_id} from database for user {user.id} - "
            "thread does not exist or does not belong to user"
        )
        raise ValueError(
            f"Thread {thread_id} does not exist or you do not have permission to access it"
        )

    logger.info(f"Successfully deleted thread {thread_id} for user {user.id}")