gateway/modules/features/chatBot/service.py
2025-10-09 16:56:27 +02:00

1022 lines
33 KiB
Python

"""Service layer for chatbot functionality."""
import json
import logging
from datetime import datetime, timezone
from typing import AsyncIterator, List, Optional
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
from modules.features.chatBot.utils.checkpointer import get_checkpointer
from modules.features.chatBot.utils.toolRegistry import get_registry
from modules.features.chatBot.utils import permissions
from modules.features.chatBot.database import UserThreadMapping
from modules.datamodels.datamodelChatbot import (
MessageItem,
ChatMessageResponse,
ThreadSummary,
ThreadDetail,
)
from modules.datamodels.datamodelUam import User
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langgraph.graph import StateGraph, MessagesState, START, END
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
async def get_all_threads_for_user(
    *,
    user: User,
    session: AsyncSession,
) -> List[ThreadSummary]:
    """Return every chat thread owned by ``user``, most recently updated first.

    Args:
        user: The current user whose threads are listed.
        session: Async database session used for the lookup.

    Returns:
        A list of ThreadSummary objects ordered by ``date_updated`` descending.
        The list is empty when the user has no threads.
    """
    logger.info(f"Fetching all threads for user {user.id}")

    query = (
        select(UserThreadMapping)
        .where(UserThreadMapping.user_id == user.id)
        .order_by(UserThreadMapping.date_updated.desc())
    )
    rows = (await session.execute(query)).scalars().all()

    # Timestamps are exposed to callers as POSIX seconds (float).
    threads = [
        ThreadSummary(
            thread_id=row.thread_id,
            thread_name=row.thread_name,
            date_created=row.date_created.timestamp(),
            date_updated=row.date_updated.timestamp(),
        )
        for row in rows
    ]

    logger.info(f"Found {len(threads)} threads for user {user.id}")
    return threads
async def save_thread_for_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
    thread_name: str = "New Chat",
) -> None:
    """Persist a new thread-to-user mapping and commit it immediately.

    Args:
        thread_id: Identifier of the chat thread being registered.
        user: The user who will own the thread.
        session: Async database session used for the insert and commit.
        thread_name: Display name for the thread; defaults to "New Chat".
    """
    logger.info(f"Saving new thread {thread_id} for user {user.id}")

    session.add(
        UserThreadMapping(
            user_id=user.id,
            thread_id=thread_id,
            thread_name=thread_name,
        )
    )
    await session.commit()

    logger.info(f"Successfully saved thread {thread_id} for user {user.id}")
async def get_or_create_thread_for_user(
    *,
    thread_id: Optional[str],
    user: User,
    session: AsyncSession,
    thread_name: str = "New Chat",
    refresh_date_updated: bool = False,
) -> str:
    """Resolve the thread to use for a request, creating one if none was given.

    A falsy ``thread_id`` (None or empty string) causes a brand-new thread id
    of the form ``thread_<uuid4>`` to be generated and saved for the user.
    Otherwise the given id is checked for existence and ownership, and its
    ``date_updated`` is optionally bumped.

    Args:
        thread_id: Existing thread identifier, or None to create a new thread.
        user: The current user.
        session: Async database session used for lookups and writes.
        thread_name: Name given to a newly created thread; defaults to "New Chat".
        refresh_date_updated: When True and an existing thread is used, its
            ``date_updated`` timestamp is refreshed. Defaults to False.

    Returns:
        The thread id to use: either the validated input or the new id.

    Raises:
        PermissionError: If the existing thread belongs to another user.
        ValueError: If the provided thread id does not exist.
    """
    if not thread_id:
        # No thread supplied: mint a fresh id and register it for this user.
        import uuid

        fresh_thread_id = f"thread_{uuid.uuid4()}"
        await save_thread_for_user(
            thread_id=fresh_thread_id,
            user=user,
            session=session,
            thread_name=thread_name,
        )
        logger.info(f"Created new thread {fresh_thread_id} for user {user.id}")
        return fresh_thread_id

    await assure_thread_exists_and_belongs_to_user(
        thread_id=thread_id, user=user, session=session
    )
    logger.info(f"Using existing thread {thread_id} for user {user.id}")

    if refresh_date_updated:
        await refresh_thread_date_updated(
            thread_id=thread_id, user=user, session=session
        )
    return thread_id
async def assure_thread_exists_and_belongs_to_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> None:
    """Verify that ``thread_id`` is a known thread owned by ``user``.

    Returns normally when both checks pass; otherwise raises.

    Args:
        thread_id: Identifier of the chat thread to check.
        user: The current user.
        session: Async database session used for the lookup.

    Raises:
        ValueError: If no mapping exists for the thread id.
        PermissionError: If the thread is owned by a different user.
    """
    query = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id)
    mapping = (await session.execute(query)).scalar_one_or_none()

    if mapping is None:
        logger.warning(f"Thread {thread_id} does not exist")
        raise ValueError(f"Thread {thread_id} does not exist")

    owner_id = mapping.user_id
    if owner_id != user.id:
        logger.warning(
            f"User {user.id} attempted to access thread {thread_id} "
            f"belonging to user {owner_id}"
        )
        raise PermissionError(
            f"You do not have permission to access thread {thread_id}"
        )

    logger.info(f"Thread {thread_id} verified for user {user.id}")
async def update_thread_name(
    *,
    thread_id: str,
    user: User,
    new_thread_name: str,
    session: AsyncSession,
) -> None:
    """Rename a chat thread owned by the user and bump its ``date_updated``.

    Ownership is enforced by the UPDATE itself. The WHERE clause matches on
    both the thread id and the user id, so a thread belonging to someone else
    is simply not matched. No separate permission lookup is needed.

    Args:
        thread_id: Identifier of the thread to rename.
        user: The current user; only their own threads can be renamed.
        new_thread_name: The name to store for the thread.
        session: Async database session used for the update and commit.

    Raises:
        ValueError: If no row matched, i.e. the thread does not exist or does
            not belong to the user.
    """
    logger.info(
        f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}"
    )

    owned_by_user = (
        UserThreadMapping.thread_id == thread_id,
        UserThreadMapping.user_id == user.id,
    )
    rename = (
        update(UserThreadMapping)
        .where(*owned_by_user)
        .values(thread_name=new_thread_name, date_updated=datetime.now(timezone.utc))
    )
    outcome = await session.execute(rename)
    await session.commit()

    if outcome.rowcount == 0:
        logger.warning(
            f"Failed to update thread {thread_id} for user {user.id} - "
            "thread does not exist or does not belong to user"
        )
        raise ValueError(
            f"Thread {thread_id} does not exist or you do not have permission to access it"
        )

    logger.info(f"Successfully updated thread {thread_id} name for user {user.id}")
async def refresh_thread_date_updated(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> None:
    """Set ``date_updated`` of a user's chat thread to the current UTC time.

    The UPDATE's WHERE clause filters on both thread id and user id. A thread
    owned by another user therefore matches no rows, which doubles as the
    permission check.

    Args:
        thread_id: Identifier of the thread to touch.
        user: The current user.
        session: Async database session used for the update and commit.

    Raises:
        ValueError: If no row matched (missing thread or foreign owner).
    """
    logger.info(f"Refreshing date_updated for thread {thread_id} for user {user.id}")

    owned_by_user = (
        UserThreadMapping.thread_id == thread_id,
        UserThreadMapping.user_id == user.id,
    )
    touch = (
        update(UserThreadMapping)
        .where(*owned_by_user)
        .values(date_updated=datetime.now(timezone.utc))
    )
    outcome = await session.execute(touch)
    await session.commit()

    if outcome.rowcount == 0:
        logger.warning(
            f"Failed to refresh thread {thread_id} for user {user.id} - "
            "thread does not exist or does not belong to user"
        )
        raise ValueError(
            f"Thread {thread_id} does not exist or you do not have permission to access it"
        )

    logger.info(
        f"Successfully refreshed date_updated for thread {thread_id} for user {user.id}"
    )
async def post_message(
    *,
    thread_id: str,
    message: str,
    user: User,
    tool_ids: List[str],
) -> ChatMessageResponse:
    """Send one message to the chatbot and return the resulting chat history.

    The model and system prompt are chosen per user. The selected tools come
    from the tool registry. Conversation memory lives in the LangGraph
    checkpointer keyed by ``thread_id``.

    Args:
        thread_id: Identifier of the chat thread (used as the chatbot chat id).
        message: The user's message text.
        user: The current user.
        tool_ids: Tool ids to enable for this call; may be empty to run
            without tools.

    Returns:
        A ChatMessageResponse with the thread id and the user/assistant
        messages returned by the chatbot. Content is stripped, and
        non-string (structured) contents are omitted.
    """
    logger.info(
        f"User {user.id} posted message to thread {thread_id} with {len(tool_ids)} tools"
    )

    model_name = permissions.get_chatbot_model(user_id=user.id)
    system_prompt = permissions.get_system_prompt(user_id=user.id)
    tools = get_registry().get_tool_instances(tool_ids=tool_ids)
    model = get_langchain_model(model_name=model_name)
    checkpointer = get_checkpointer()
    context_window_size = int(
        APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
    )

    chatbot = await Chatbot.create(
        model=model,
        memory=checkpointer,
        system_prompt=system_prompt,
        tools=tools,
        context_window_size=context_window_size,
    )
    history = await chatbot.chat(message=message, chat_id=thread_id)

    # Only human and AI messages are surfaced; everything else (system, tool, ...)
    # is dropped. The first matching class wins.
    role_by_class = ((HumanMessage, "user"), (AIMessage, "assistant"))

    visible: List[MessageItem] = []
    for entry in history:
        role = next(
            (name for cls, name in role_by_class if isinstance(entry, cls)), None
        )
        if role is None:
            continue
        # Structured content (e.g. tool-call payloads) is not plain text; skip it.
        if not isinstance(entry.content, str):
            continue
        visible.append(MessageItem(role=role, content=entry.content.strip()))

    return ChatMessageResponse(thread_id=thread_id, messages=visible)
async def post_message_stream(
    *,
    thread_id: str,
    message: str,
    user: User,
    tool_ids: List[str],
) -> AsyncIterator[str]:
    """Post a chat message to the chatbot and stream progress updates (SSE).

    Every yielded string is one Server-Sent Events frame (``data: <json>``
    followed by a blank line). Each JSON object carries a ``type`` key:

    - ``status``: a progress label forwarded from the chatbot (zero or more).
    - ``final``: the normalized chat history; ends the stream.
    - ``error``: an error description; ends the stream.

    Unless the consumer stops iterating early, the stream ends with a single
    ``final`` or ``error`` frame. This includes the case where the chatbot's
    event stream finishes without emitting either of them.

    Args:
        thread_id: The unique identifier for the chat thread.
        message: The content of the chat message.
        user: The current user.
        tool_ids: List of tool IDs to use for this chat. Can be empty to run
            without tools.

    Yields:
        Server-Sent Events formatted strings containing status updates and
        the final response (or an error).
    """
    logger.info(
        f"User {user.id} streaming message to thread {thread_id} with {len(tool_ids)} tools"
    )

    def _sse(payload: dict) -> str:
        # Serialize one payload as a single SSE "data:" frame.
        return f"data: {json.dumps(payload)}\n\n"

    try:
        model_name = permissions.get_chatbot_model(user_id=user.id)
        system_prompt = permissions.get_system_prompt(user_id=user.id)

        # Tools from the registry (empty list if no tools were requested).
        registry = get_registry()
        tools = registry.get_tool_instances(tool_ids=tool_ids)

        model = get_langchain_model(model_name=model_name)
        checkpointer = get_checkpointer()
        context_window_size = int(
            APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
        )

        chatbot = await Chatbot.create(
            model=model,
            memory=checkpointer,
            system_prompt=system_prompt,
            tools=tools,
            context_window_size=context_window_size,
        )

        async for event in chatbot.stream_events(message=message, chat_id=thread_id):
            etype = event.get("type")

            if etype == "status":
                yield _sse({"type": "status", "label": event.get("label")})
                continue

            if etype == "final":
                response_from_event = event.get("response") or {}
                # chat_history is expected to be a list of {"role", "content"}
                # dicts already normalized by stream_events.
                chat_history_payload = response_from_event.get("chat_history", [])

                items: List[MessageItem] = []
                if isinstance(chat_history_payload, list):
                    for it in chat_history_payload:
                        role = it.get("role")
                        content = it.get("content", "")
                        if role in ("user", "assistant") and content:
                            items.append(MessageItem(role=role, content=content))
                else:
                    # Unexpected payload format: report an empty history.
                    logger.warning(
                        f"Unexpected chat_history format in final event: {type(chat_history_payload)}"
                    )

                response = ChatMessageResponse(thread_id=thread_id, messages=items)
                yield _sse({"type": "final", "response": response.model_dump()})
                return

            if etype == "error":
                yield _sse(event)
                return

        # The chatbot stream ended without a terminal event. Tell the client
        # explicitly instead of silently closing the stream.
        logger.warning(
            f"Chat stream for thread {thread_id} ended without a final or error event"
        )
        yield _sse(
            {
                "type": "error",
                "message": "The chatbot finished without producing a response.",
            }
        )
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e) or 'No error message provided'}"
        logger.error(f"Error in streaming chat: {error_msg}", exc_info=True)
        yield _sse(
            {
                "type": "error",
                "message": f"An error occurred while processing your request: {error_msg}",
            }
        )
# Module-level singleton for minimal app used to read thread state.
# It is None until _get_minimal_app() builds it on first access; after that the
# same compiled graph is reused. It stays bound to whichever checkpointer
# get_checkpointer() returned at build time.
_MINIMAL_APP = None
def _build_minimal_app(*, checkpointer):
    """Compile a do-nothing LangGraph app whose only purpose is state reads.

    LangGraph refuses to compile a graph without a path from START. This
    therefore wires a single pass-through node START -> "noop" -> END. The
    graph is never invoked; callers use it only for ``aget_state()`` against
    the attached checkpointer.

    Args:
        checkpointer: The checkpointer the compiled graph should read from.

    Returns:
        The compiled StateGraph (MessagesState schema).
    """

    def passthrough(state: dict) -> dict:
        # Returns its input unchanged; exists only so the graph is valid.
        return state

    builder = StateGraph(MessagesState)
    builder.add_node("noop", passthrough)
    builder.add_edge(START, "noop")
    builder.add_edge("noop", END)

    compiled = builder.compile(checkpointer=checkpointer)
    return compiled
def _get_minimal_app():
    """Return the cached minimal state-reading app, building it lazily.

    Returns:
        The module-level compiled app. It is created from the current
        checkpointer the first time this function is called.
    """
    global _MINIMAL_APP

    if _MINIMAL_APP is not None:
        return _MINIMAL_APP

    _MINIMAL_APP = _build_minimal_app(checkpointer=get_checkpointer())
    return _MINIMAL_APP
async def get_thread_messages_from_langgraph(
    *,
    thread_id: str,
) -> List[dict]:
    """Read a thread's stored messages from the LangGraph checkpointer.

    Only human and AI messages with plain-string content are returned. System
    and tool messages, and structured contents such as tool calls, are skipped.

    Args:
        thread_id: The unique identifier for the chat thread.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in stored order.
        "human" is mapped to "user" and "ai" to "assistant".
    """
    role_for_type = {"human": "user", "ai": "assistant"}

    app = _get_minimal_app()
    snapshot = await app.aget_state({"configurable": {"thread_id": thread_id}})

    return [
        {"role": role_for_type.get(stored.type, stored.type), "content": stored.content}
        for stored in snapshot.values.get("messages", [])
        if stored.type in role_for_type and isinstance(stored.content, str)
    ]
async def get_thread_detail_for_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> ThreadDetail:
    """Return a thread's metadata together with its stored message history.

    Ownership is verified first. Metadata is then read from the database, and
    messages are read from the LangGraph checkpointer without building a
    full chatbot.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: Async database session used for the lookups.

    Returns:
        ThreadDetail holding the id, creation/update timestamps (POSIX
        seconds) and the user/assistant messages.

    Raises:
        PermissionError: If the thread does not belong to the user.
        ValueError: If the thread does not exist.
    """
    logger.info(f"Getting thread detail for thread {thread_id} for user {user.id}")

    await assure_thread_exists_and_belongs_to_user(
        thread_id=thread_id, user=user, session=session
    )

    lookup = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id)
    mapping = (await session.execute(lookup)).scalar_one()

    raw_messages = await get_thread_messages_from_langgraph(thread_id=thread_id)
    history = [MessageItem(**entry) for entry in raw_messages]

    logger.info(
        f"Retrieved thread {thread_id} with {len(history)} messages for user {user.id}"
    )

    return ThreadDetail(
        thread_id=thread_id,
        date_created=mapping.date_created.timestamp(),
        date_updated=mapping.date_updated.timestamp(),
        messages=history,
    )
async def delete_thread_for_user(
    *,
    thread_id: str,
    user: User,
    session: AsyncSession,
) -> None:
    """Delete a chat thread for a user from both LangGraph and the database.

    The checkpointer data is removed before the database mapping. If the
    checkpointer delete fails, the mapping is kept, so the user can still see
    the thread and retry the deletion.

    Args:
        thread_id: The unique identifier for the chat thread.
        user: The current user.
        session: The database session for deleting.

    Raises:
        PermissionError: If the thread does not belong to the user.
        ValueError: If the thread does not exist, if the checkpointer delete
            fails (the original exception is chained as ``__cause__``), or if
            no mapping row was deleted.
    """
    logger.info(f"Deleting thread {thread_id} for user {user.id}")

    await assure_thread_exists_and_belongs_to_user(
        thread_id=thread_id, user=user, session=session
    )

    # Delete from the LangGraph checkpointer directly (no app/tools/model needed).
    checkpointer = get_checkpointer()
    try:
        await checkpointer.adelete_thread(thread_id)
    except Exception as e:
        logger.error(
            f"Failed to delete thread {thread_id} from LangGraph: {type(e).__name__}: {str(e)}",
            exc_info=True,
        )
        # Chain the cause explicitly so the underlying failure stays visible.
        raise ValueError(
            f"Failed to delete thread from LangGraph: {type(e).__name__}: {str(e)}"
        ) from e
    logger.info(f"Deleted thread {thread_id} from LangGraph checkpointer")

    # The user_id filter keeps the delete scoped to the caller's own row.
    stmt = delete(UserThreadMapping).where(
        UserThreadMapping.thread_id == thread_id,
        UserThreadMapping.user_id == user.id,
    )
    result = await session.execute(stmt)
    await session.commit()

    if result.rowcount == 0:
        logger.warning(
            f"Failed to delete thread {thread_id} from database for user {user.id} - "
            "thread does not exist or does not belong to user"
        )
        raise ValueError(
            f"Thread {thread_id} does not exist or you do not have permission to access it"
        )

    logger.info(f"Successfully deleted thread {thread_id} for user {user.id}")
# Tool Management Functions
async def get_all_tools(*, session: AsyncSession) -> List[dict]:
    """List every tool stored in the database, ordered by category then name.

    Args:
        session: Async database session used for the query.

    Returns:
        One dict per tool with keys ``id`` (stringified primary key),
        ``tool_id``, ``name``, ``label``, ``category``, ``description``,
        ``is_active``, ``date_created`` and ``date_updated`` (POSIX seconds).
    """
    from modules.features.chatBot.database import Tool

    logger.info("Fetching all tools from database")

    query = select(Tool).order_by(Tool.category, Tool.name)
    rows = (await session.execute(query)).scalars().all()

    def as_dict(tool) -> dict:
        return {
            "id": str(tool.id),
            "tool_id": tool.tool_id,
            "name": tool.name,
            "label": tool.label,
            "category": tool.category,
            "description": tool.description,
            "is_active": tool.is_active,
            "date_created": tool.date_created.timestamp(),
            "date_updated": tool.date_updated.timestamp(),
        }

    catalog = [as_dict(tool) for tool in rows]

    logger.info(f"Retrieved {len(catalog)} tools from database")
    return catalog
async def grant_tool_to_user(
    *, user_id: str, tool_id: str, session: AsyncSession
) -> None:
    """Grant a tool to a user.

    If a mapping for this user and tool already exists but is inactive
    (``is_active`` False), it is reactivated instead of being rejected. An
    inactive mapping does not give access, because the tool listings only
    consider mappings with ``is_active == True``. Rejecting it as "already has
    access" would therefore make re-granting impossible.

    Args:
        user_id: The user ID to grant the tool to.
        tool_id: The tool UUID (string form) from the tools table.
        session: The database session for querying/updating.

    Raises:
        ValueError: If ``tool_id`` is not a valid UUID, the tool doesn't
            exist, the tool is not active, or the user already has an active
            mapping for it.
    """
    from modules.features.chatBot.database import Tool, UserToolMapping
    import uuid

    logger.info(f"Granting tool {tool_id} to user {user_id}")

    try:
        tool_uuid = uuid.UUID(tool_id)
    except ValueError as err:
        raise ValueError(f"Invalid tool ID format: {tool_id}") from err

    # The tool itself must exist and be active before it can be granted.
    stmt = select(Tool).where(Tool.id == tool_uuid)
    result = await session.execute(stmt)
    tool = result.scalar_one_or_none()
    if tool is None:
        raise ValueError(f"Tool with ID {tool_id} does not exist")
    if not tool.is_active:
        raise ValueError(
            f"Cannot grant inactive tool '{tool.label}' (tool_id: {tool.tool_id}). "
            f"Please activate the tool first before granting it to users."
        )

    stmt = select(UserToolMapping).where(
        UserToolMapping.user_id == user_id, UserToolMapping.tool_id == tool_uuid
    )
    result = await session.execute(stmt)
    existing_mapping = result.scalar_one_or_none()

    if existing_mapping is not None:
        if existing_mapping.is_active:
            raise ValueError(
                f"User {user_id} already has access to tool '{tool.label}' (tool_id: {tool.tool_id})"
            )
        # A deactivated mapping exists: re-enable it rather than inserting a
        # duplicate row or wrongly reporting that the user has access.
        existing_mapping.is_active = True
        await session.commit()
        logger.info(
            f"Reactivated tool {tool_id} ({tool.label}) for user {user_id}"
        )
        return

    session.add(
        UserToolMapping(
            user_id=user_id,
            tool_id=tool_uuid,
            is_active=True,
        )
    )
    await session.commit()
    logger.info(f"Successfully granted tool {tool_id} ({tool.label}) to user {user_id}")
async def revoke_tool_from_user(
    *, user_id: str, tool_id: str, session: AsyncSession
) -> None:
    """Remove a user's access to a tool by deleting their mapping row(s).

    Args:
        user_id: The user ID to revoke the tool from.
        tool_id: The tool UUID (string form) from the tools table.
        session: Async database session used for the delete and commit.

    Raises:
        ValueError: If ``tool_id`` is not a valid UUID, or no mapping was
            deleted.
    """
    from modules.features.chatBot.database import UserToolMapping
    import uuid

    logger.info(f"Revoking tool {tool_id} from user {user_id}")

    try:
        parsed_tool_id = uuid.UUID(tool_id)
    except ValueError:
        raise ValueError(f"Invalid tool ID format: {tool_id}")

    removal = delete(UserToolMapping).where(
        UserToolMapping.user_id == user_id,
        UserToolMapping.tool_id == parsed_tool_id,
    )
    outcome = await session.execute(removal)
    await session.commit()

    if outcome.rowcount == 0:
        raise ValueError(
            f"User {user_id} does not have access to tool {tool_id}, or the mapping does not exist"
        )

    logger.info(f"Successfully revoked tool {tool_id} from user {user_id}")
async def update_tool(
    *,
    tool_id: str,
    label: Optional[str],
    description: Optional[str],
    session: AsyncSession,
) -> List[str]:
    """Change a tool's label and/or description, stamping ``date_updated``.

    Arguments passed as None are left untouched.

    Args:
        tool_id: The tool UUID (string form) to update.
        label: New label, or None to keep the current one.
        description: New description, or None to keep the current one.
        session: Async database session used for the lookup, update and commit.

    Returns:
        Names of the fields that were changed, in the order
        ``label``, ``description`` (only those provided).

    Raises:
        ValueError: If both fields are None, ``tool_id`` is not a valid UUID,
            or the tool does not exist.
    """
    from modules.features.chatBot.database import Tool
    import uuid

    logger.info(f"Updating tool {tool_id}")

    requested = {
        field: value
        for field, value in (("label", label), ("description", description))
        if value is not None
    }
    if not requested:
        raise ValueError("At least one field (label or description) must be provided")

    try:
        parsed_tool_id = uuid.UUID(tool_id)
    except ValueError:
        raise ValueError(f"Invalid tool ID format: {tool_id}")

    lookup = select(Tool).where(Tool.id == parsed_tool_id)
    if (await session.execute(lookup)).scalar_one_or_none() is None:
        raise ValueError(f"Tool with ID {tool_id} does not exist")

    changes = {"date_updated": datetime.now(timezone.utc), **requested}
    changed_fields = list(requested)

    await session.execute(update(Tool).where(Tool.id == parsed_tool_id).values(**changes))
    await session.commit()

    logger.info(f"Successfully updated tool {tool_id}, fields: {changed_fields}")
    return changed_fields
async def get_tools_for_user(*, user_id: str, session: AsyncSession) -> List[dict]:
    """List the tools a user can currently use, ordered by category then name.

    A tool is included only when both the user's mapping and the tool itself
    are marked active.

    Args:
        user_id: The user whose granted tools are listed.
        session: Async database session used for the query.

    Returns:
        One dict per tool with keys ``id`` (stringified primary key),
        ``tool_id``, ``name``, ``label``, ``category``, ``description``,
        ``is_active``, ``date_created`` and ``date_updated`` (POSIX seconds).
    """
    from modules.features.chatBot.database import Tool, UserToolMapping

    logger.info(f"Fetching tools for user {user_id}")

    active_grants = (
        UserToolMapping.user_id == user_id,
        UserToolMapping.is_active == True,  # noqa: E712 - SQL expression
        Tool.is_active == True,  # noqa: E712 - SQL expression
    )
    query = (
        select(Tool)
        .join(UserToolMapping, Tool.id == UserToolMapping.tool_id)
        .where(*active_grants)
        .order_by(Tool.category, Tool.name)
    )
    rows = (await session.execute(query)).scalars().all()

    granted = [
        {
            "id": str(tool.id),
            "tool_id": tool.tool_id,
            "name": tool.name,
            "label": tool.label,
            "category": tool.category,
            "description": tool.description,
            "is_active": tool.is_active,
            "date_created": tool.date_created.timestamp(),
            "date_updated": tool.date_updated.timestamp(),
        }
        for tool in rows
    ]

    logger.info(f"Retrieved {len(granted)} tools for user {user_id}")
    return granted
async def validate_and_get_tools_for_request(
    *,
    user_id: str,
    requested_tool_ids: Optional[List[str]],
    session: AsyncSession,
) -> List[str]:
    """Validate and get tool IDs for a chat request.

    This function validates that the user has access to the requested tools.
    If no tools are requested (None), it returns all tools the user has
    access to. If an empty list is provided, it returns an empty list (no
    tools).

    The result is deterministic and free of duplicates. "All tools" are
    returned in (category, name) order. Explicit requests keep the caller's
    order, with repeated UUIDs collapsed to their first occurrence.

    Args:
        user_id: The user ID making the request.
        requested_tool_ids: Optional list of tool UUIDs (id field) requested by the user.
            - None: Use all tools the user has access to
            - []: Use no tools at all
            - ["uuid1", "uuid2"]: Use only the specified tools
        session: The database session for querying.

    Returns:
        List of validated tool IDs (tool_id field, not UUID) that the user can use.

    Raises:
        PermissionError: If the user requests tools they don't have access to.
        ValueError: If the user has no active tools at all. This is checked
            for any non-empty request, not only when requesting all tools.
    """
    from modules.features.chatBot.database import Tool, UserToolMapping

    logger.info(f"Validating tools for user {user_id}")

    # An explicitly empty list means "run without tools".
    if requested_tool_ids is not None and len(requested_tool_ids) == 0:
        logger.info(
            f"Empty tool list requested, chatbot will run without tools for user {user_id}"
        )
        return []

    # Ordered query so the "all tools" result is stable across calls and processes.
    stmt = (
        select(Tool)
        .join(UserToolMapping, Tool.id == UserToolMapping.tool_id)
        .where(
            UserToolMapping.user_id == user_id,
            UserToolMapping.is_active == True,  # noqa: E712 - SQL expression
            Tool.is_active == True,  # noqa: E712 - SQL expression
        )
        .order_by(Tool.category, Tool.name)
    )
    result = await session.execute(stmt)
    user_tools = result.scalars().all()

    user_tool_ids_by_uuid = {str(tool.id): tool.tool_id for tool in user_tools}
    # dict.fromkeys de-duplicates while keeping query order (a set would not).
    user_tool_ids = list(dict.fromkeys(user_tool_ids_by_uuid.values()))

    if not user_tool_ids:
        logger.warning(f"User {user_id} has no tools available")
        raise ValueError("User does not have access to any chatbot tools")

    if requested_tool_ids is None:
        logger.info(
            f"No specific tools requested, returning all {len(user_tool_ids)} tools for user {user_id}"
        )
        return user_tool_ids

    # Map requested UUIDs to tool_ids. Repeated UUIDs are collapsed so the same
    # tool is never handed to the chatbot twice.
    requested_tool_ids_result = []
    unauthorized_uuids = []
    for requested_uuid in dict.fromkeys(requested_tool_ids):
        if requested_uuid in user_tool_ids_by_uuid:
            requested_tool_ids_result.append(user_tool_ids_by_uuid[requested_uuid])
        else:
            unauthorized_uuids.append(requested_uuid)

    if unauthorized_uuids:
        logger.warning(
            f"User {user_id} requested unauthorized tool UUIDs: {unauthorized_uuids}"
        )
        raise PermissionError(
            f"You do not have access to the following tools: {', '.join(unauthorized_uuids)}"
        )

    logger.info(
        f"Validated {len(requested_tool_ids_result)} requested tools for user {user_id}"
    )
    return requested_tool_ids_result