feat: allow users to specify tools when posting messages

This commit is contained in:
Christopher Gondek 2025-10-09 16:56:27 +02:00
parent 2b5d7506d0
commit ba1daa2d73
3 changed files with 144 additions and 25 deletions

View file

@ -20,6 +20,10 @@ class ChatMessageRequest(BaseModel, ModelMixin):
None, description="Thread ID (creates new thread if not provided)"
)
message: str = Field(..., description="User message content")
tools: Optional[List[str]] = Field(
None,
description="List of tool IDs to use. If not provided, all user's tools will be used",
)
class ChatMessageResponse(BaseModel, ModelMixin):

View file

@ -301,6 +301,7 @@ async def post_message(
thread_id: str,
message: str,
user: User,
tool_ids: List[str],
) -> ChatMessageResponse:
"""Post a chat message to the chatbot and return the response.
@ -308,21 +309,19 @@ async def post_message(
thread_id: The unique identifier for the chat thread.
message: The content of the chat message.
user: The current user.
tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools.
Returns:
The response containing the full chat message history and thread ID.
"""
logger.info(f"User {user.id} posted message to thread {thread_id}")
# Get user permissions
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
if not tool_ids:
raise ValueError("User does not have permission to use any chatbot tools")
logger.info(
f"User {user.id} posted message to thread {thread_id} with {len(tool_ids)} tools"
)
model_name = permissions.get_chatbot_model(user_id=user.id)
system_prompt = permissions.get_system_prompt(user_id=user.id)
# Get tools from registry
# Get tools from registry (empty list if no tools)
registry = get_registry()
tools = registry.get_tool_instances(tool_ids=tool_ids)
@ -377,6 +376,7 @@ async def post_message_stream(
thread_id: str,
message: str,
user: User,
tool_ids: List[str],
) -> AsyncIterator[str]:
"""Post a chat message to the chatbot and stream progress updates (SSE).
@ -384,32 +384,20 @@ async def post_message_stream(
thread_id: The unique identifier for the chat thread.
message: The content of the chat message.
user: The current user.
tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools.
Yields:
Server-Sent Events formatted strings containing status updates and final response.
"""
logger.info(f"User {user.id} streaming message to thread {thread_id}")
logger.info(
f"User {user.id} streaming message to thread {thread_id} with {len(tool_ids)} tools"
)
try:
# Get user permissions
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
if not tool_ids:
yield (
"data: "
+ json.dumps(
{
"type": "error",
"message": "User does not have permission to use any chatbot tools",
}
)
+ "\n\n"
)
return
model_name = permissions.get_chatbot_model(user_id=user.id)
system_prompt = permissions.get_system_prompt(user_id=user.id)
# Get tools from registry
# Get tools from registry (empty list if no tools)
registry = get_registry()
tools = registry.get_tool_instances(tool_ids=tool_ids)
@ -939,3 +927,96 @@ async def get_tools_for_user(*, user_id: str, session: AsyncSession) -> List[dic
logger.info(f"Retrieved {len(tool_list)} tools for user {user_id}")
return tool_list
async def validate_and_get_tools_for_request(
*,
user_id: str,
requested_tool_ids: Optional[List[str]],
session: AsyncSession,
) -> List[str]:
"""Validate and get tool IDs for a chat request.
This function validates that the user has access to the requested tools.
If no tools are requested (None), it returns all tools the user has access to.
If an empty list is provided, it returns an empty list (no tools).
Args:
user_id: The user ID making the request.
requested_tool_ids: Optional list of tool UUIDs (id field) requested by the user.
- None: Use all tools the user has access to
- []: Use no tools at all
- ["uuid1", "uuid2"]: Use only the specified tools
session: The database session for querying.
Returns:
List of validated tool IDs (tool_id field, not UUID) that the user can use.
Raises:
PermissionError: If the user requests tools they don't have access to.
ValueError: If the user has no tools available when trying to use all tools.
"""
from modules.features.chatBot.database import Tool, UserToolMapping
import uuid
logger.info(f"Validating tools for user {user_id}")
# If empty list is explicitly provided, return empty list (no tools)
if requested_tool_ids is not None and len(requested_tool_ids) == 0:
logger.info(
f"Empty tool list requested, chatbot will run without tools for user {user_id}"
)
return []
# Get all tools the user has access to
stmt = (
select(Tool)
.join(UserToolMapping, Tool.id == UserToolMapping.tool_id)
.where(
UserToolMapping.user_id == user_id,
UserToolMapping.is_active == True,
Tool.is_active == True,
)
)
result = await session.execute(stmt)
user_tools = result.scalars().all()
# Create mappings for both UUID and tool_id
user_tool_ids_by_uuid = {str(tool.id): tool.tool_id for tool in user_tools}
user_tool_ids = set(user_tool_ids_by_uuid.values())
if not user_tool_ids:
logger.warning(f"User {user_id} has no tools available")
raise ValueError("User does not have access to any chatbot tools")
# If no specific tools requested (None), return all user's tools
if requested_tool_ids is None:
logger.info(
f"No specific tools requested, returning all {len(user_tool_ids)} tools for user {user_id}"
)
return list(user_tool_ids)
# Convert requested UUIDs to tool_ids and validate access
requested_tool_ids_result = []
unauthorized_uuids = []
for requested_uuid in requested_tool_ids:
if requested_uuid in user_tool_ids_by_uuid:
# User has access to this tool
requested_tool_ids_result.append(user_tool_ids_by_uuid[requested_uuid])
else:
# User doesn't have access to this tool
unauthorized_uuids.append(requested_uuid)
if unauthorized_uuids:
logger.warning(
f"User {user_id} requested unauthorized tool UUIDs: {unauthorized_uuids}"
)
raise PermissionError(
f"You do not have access to the following tools: {', '.join(unauthorized_uuids)}"
)
logger.info(
f"Validated {len(requested_tool_ids_result)} requested tools for user {user_id}"
)
return requested_tool_ids_result

View file

@ -62,6 +62,13 @@ async def post_chat_message_stream(
Returns Server-Sent Events (SSE) stream with status updates and final response.
"""
try:
# Validate and get tools for the request
tool_ids = await chat_service.validate_and_get_tools_for_request(
user_id=currentUser.id,
requested_tool_ids=message_request.tools,
session=session,
)
# Get or create thread using helper function
thread_id = await get_or_create_thread_for_user(
thread_id=message_request.thread_id,
@ -80,6 +87,7 @@ async def post_chat_message_stream(
thread_id=thread_id,
message=message_request.message,
user=currentUser,
tool_ids=tool_ids,
),
media_type="text/event-stream",
headers={
@ -88,6 +96,18 @@ async def post_chat_message_stream(
},
)
except PermissionError as e:
logger.error(f"Permission error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e) or "Permission denied",
)
except ValueError as e:
logger.error(f"Validation error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e) or "Permission denied",
)
except Exception as e:
logger.error(
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
@ -114,6 +134,13 @@ async def post_chat_message(
For streaming updates, use the /message/stream endpoint instead.
"""
try:
# Validate and get tools for the request
tool_ids = await chat_service.validate_and_get_tools_for_request(
user_id=currentUser.id,
requested_tool_ids=message_request.tools,
session=session,
)
# Get or create thread using helper function
thread_id = await get_or_create_thread_for_user(
thread_id=message_request.thread_id,
@ -129,16 +156,23 @@ async def post_chat_message(
thread_id=thread_id,
message=message_request.message,
user=currentUser,
tool_ids=tool_ids,
)
return response
except ValueError as e:
except PermissionError as e:
logger.error(f"Permission error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e) or "Permission denied",
)
except ValueError as e:
logger.error(f"Validation error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e) or "Permission denied",
)
except Exception as e:
logger.error(
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True