From ba1daa2d7317d494de91b50e4f95374713701073 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Thu, 9 Oct 2025 16:56:27 +0200 Subject: [PATCH] feat: allow users to specify tools when posting messages --- modules/datamodels/datamodelChatbot.py | 4 + modules/features/chatBot/service.py | 129 ++++++++++++++++++++----- modules/routes/routeChatbot.py | 36 ++++++- 3 files changed, 144 insertions(+), 25 deletions(-) diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py index d4bb2f3a..762996b3 100644 --- a/modules/datamodels/datamodelChatbot.py +++ b/modules/datamodels/datamodelChatbot.py @@ -20,6 +20,10 @@ class ChatMessageRequest(BaseModel, ModelMixin): None, description="Thread ID (creates new thread if not provided)" ) message: str = Field(..., description="User message content") + tools: Optional[List[str]] = Field( + None, + description="List of tool IDs to use. If not provided, all user's tools will be used", + ) class ChatMessageResponse(BaseModel, ModelMixin): diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index e4b63ff9..8239ba53 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -301,6 +301,7 @@ async def post_message( thread_id: str, message: str, user: User, + tool_ids: List[str], ) -> ChatMessageResponse: """Post a chat message to the chatbot and return the response. @@ -308,21 +309,19 @@ async def post_message( thread_id: The unique identifier for the chat thread. message: The content of the chat message. user: The current user. + tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools. Returns: The response containing the full chat message history and thread ID. """ - logger.info(f"User {user.id} posted message to thread {thread_id}") - - # Get user permissions - tool_ids = permissions.get_chatbot_tools(user_id=user.id) - if not tool_ids: - raise ValueError("User does not have permission to use any chatbot tools") + logger.info( + f"User {user.id} posted message to thread {thread_id} with {len(tool_ids)} tools" + ) model_name = permissions.get_chatbot_model(user_id=user.id) system_prompt = permissions.get_system_prompt(user_id=user.id) - # Get tools from registry + # Get tools from registry (empty list if no tools) registry = get_registry() tools = registry.get_tool_instances(tool_ids=tool_ids) @@ -377,6 +376,7 @@ async def post_message_stream( thread_id: str, message: str, user: User, + tool_ids: List[str], ) -> AsyncIterator[str]: """Post a chat message to the chatbot and stream progress updates (SSE). @@ -384,32 +384,20 @@ async def post_message_stream( thread_id: The unique identifier for the chat thread. message: The content of the chat message. user: The current user. + tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools. Yields: Server-Sent Events formatted strings containing status updates and final response. """ - logger.info(f"User {user.id} streaming message to thread {thread_id}") + logger.info( + f"User {user.id} streaming message to thread {thread_id} with {len(tool_ids)} tools" + ) try: - # Get user permissions - tool_ids = permissions.get_chatbot_tools(user_id=user.id) - if not tool_ids: - yield ( - "data: " - + json.dumps( - { - "type": "error", - "message": "User does not have permission to use any chatbot tools", - } - ) - + "\n\n" - ) - return - model_name = permissions.get_chatbot_model(user_id=user.id) system_prompt = permissions.get_system_prompt(user_id=user.id) - # Get tools from registry + # Get tools from registry (empty list if no tools) registry = get_registry() tools = registry.get_tool_instances(tool_ids=tool_ids) @@ -939,3 +927,96 @@ async def get_tools_for_user(*, user_id: str, session: AsyncSession) -> List[dic logger.info(f"Retrieved {len(tool_list)} tools for user {user_id}") return tool_list + + +async def validate_and_get_tools_for_request( + *, + user_id: str, + requested_tool_ids: Optional[List[str]], + session: AsyncSession, +) -> List[str]: + """Validate and get tool IDs for a chat request. + + This function validates that the user has access to the requested tools. + If no tools are requested (None), it returns all tools the user has access to. + If an empty list is provided, it returns an empty list (no tools). + + Args: + user_id: The user ID making the request. + requested_tool_ids: Optional list of tool UUIDs (id field) requested by the user. + - None: Use all tools the user has access to + - []: Use no tools at all + - ["uuid1", "uuid2"]: Use only the specified tools + session: The database session for querying. + + Returns: + List of validated tool IDs (tool_id field, not UUID) that the user can use. + + Raises: + PermissionError: If the user requests tools they don't have access to. + ValueError: If the user has no tools available when trying to use all tools. + """ + from modules.features.chatBot.database import Tool, UserToolMapping + import uuid + + logger.info(f"Validating tools for user {user_id}") + + # If empty list is explicitly provided, return empty list (no tools) + if requested_tool_ids is not None and len(requested_tool_ids) == 0: + logger.info( + f"Empty tool list requested, chatbot will run without tools for user {user_id}" + ) + return [] + + # Get all tools the user has access to + stmt = ( + select(Tool) + .join(UserToolMapping, Tool.id == UserToolMapping.tool_id) + .where( + UserToolMapping.user_id == user_id, + UserToolMapping.is_active == True, + Tool.is_active == True, + ) + ) + result = await session.execute(stmt) + user_tools = result.scalars().all() + + # Create mappings for both UUID and tool_id + user_tool_ids_by_uuid = {str(tool.id): tool.tool_id for tool in user_tools} + user_tool_ids = set(user_tool_ids_by_uuid.values()) + + if not user_tool_ids: + logger.warning(f"User {user_id} has no tools available") + raise ValueError("User does not have access to any chatbot tools") + + # If no specific tools requested (None), return all user's tools + if requested_tool_ids is None: + logger.info( + f"No specific tools requested, returning all {len(user_tool_ids)} tools for user {user_id}" + ) + return list(user_tool_ids) + + # Convert requested UUIDs to tool_ids and validate access + requested_tool_ids_result = [] + unauthorized_uuids = [] + + for requested_uuid in requested_tool_ids: + if requested_uuid in user_tool_ids_by_uuid: + # User has access to this tool + requested_tool_ids_result.append(user_tool_ids_by_uuid[requested_uuid]) + else: + # User doesn't have access to this tool + unauthorized_uuids.append(requested_uuid) + + if unauthorized_uuids: + logger.warning( + f"User {user_id} requested unauthorized tool UUIDs: {unauthorized_uuids}" + ) + raise PermissionError( + f"You do not have access to the following tools: {', '.join(unauthorized_uuids)}" + ) + + logger.info( + f"Validated {len(requested_tool_ids_result)} requested tools for user {user_id}" + ) + return requested_tool_ids_result diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index 65151017..a4757c84 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -62,6 +62,13 @@ async def post_chat_message_stream( Returns Server-Sent Events (SSE) stream with status updates and final response. """ try: + # Validate and get tools for the request + tool_ids = await chat_service.validate_and_get_tools_for_request( + user_id=currentUser.id, + requested_tool_ids=message_request.tools, + session=session, + ) + # Get or create thread using helper function thread_id = await get_or_create_thread_for_user( thread_id=message_request.thread_id, @@ -80,6 +87,7 @@ async def post_chat_message_stream( thread_id=thread_id, message=message_request.message, user=currentUser, + tool_ids=tool_ids, ), media_type="text/event-stream", headers={ @@ -88,6 +96,18 @@ async def post_chat_message_stream( }, ) + except PermissionError as e: + logger.error(f"Permission error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) or "Permission denied", + ) + except ValueError as e: + logger.error(f"Validation error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) or "Permission denied", + ) except Exception as e: logger.error( f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True @@ -114,6 +134,13 @@ async def post_chat_message( For streaming updates, use the /message/stream endpoint instead. """ try: + # Validate and get tools for the request + tool_ids = await chat_service.validate_and_get_tools_for_request( + user_id=currentUser.id, + requested_tool_ids=message_request.tools, + session=session, + ) + # Get or create thread using helper function thread_id = await get_or_create_thread_for_user( thread_id=message_request.thread_id, @@ -129,16 +156,23 @@ async def post_chat_message( thread_id=thread_id, message=message_request.message, user=currentUser, + tool_ids=tool_ids, ) return response - except ValueError as e: + except PermissionError as e: logger.error(f"Permission error: {str(e)}", exc_info=True) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=str(e) or "Permission denied", ) + except ValueError as e: + logger.error(f"Validation error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) or "Permission denied", + ) except Exception as e: logger.error( f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True