feat: allow users to specify tools when posting messages
This commit is contained in:
parent
2b5d7506d0
commit
ba1daa2d73
3 changed files with 144 additions and 25 deletions
|
|
@ -20,6 +20,10 @@ class ChatMessageRequest(BaseModel, ModelMixin):
|
|||
None, description="Thread ID (creates new thread if not provided)"
|
||||
)
|
||||
message: str = Field(..., description="User message content")
|
||||
tools: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="List of tool IDs to use. If not provided, all user's tools will be used",
|
||||
)
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel, ModelMixin):
|
||||
|
|
|
|||
|
|
@ -301,6 +301,7 @@ async def post_message(
|
|||
thread_id: str,
|
||||
message: str,
|
||||
user: User,
|
||||
tool_ids: List[str],
|
||||
) -> ChatMessageResponse:
|
||||
"""Post a chat message to the chatbot and return the response.
|
||||
|
||||
|
|
@ -308,21 +309,19 @@ async def post_message(
|
|||
thread_id: The unique identifier for the chat thread.
|
||||
message: The content of the chat message.
|
||||
user: The current user.
|
||||
tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools.
|
||||
|
||||
Returns:
|
||||
The response containing the full chat message history and thread ID.
|
||||
"""
|
||||
logger.info(f"User {user.id} posted message to thread {thread_id}")
|
||||
|
||||
# Get user permissions
|
||||
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
||||
if not tool_ids:
|
||||
raise ValueError("User does not have permission to use any chatbot tools")
|
||||
logger.info(
|
||||
f"User {user.id} posted message to thread {thread_id} with {len(tool_ids)} tools"
|
||||
)
|
||||
|
||||
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||
|
||||
# Get tools from registry
|
||||
# Get tools from registry (empty list if no tools)
|
||||
registry = get_registry()
|
||||
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||
|
||||
|
|
@ -377,6 +376,7 @@ async def post_message_stream(
|
|||
thread_id: str,
|
||||
message: str,
|
||||
user: User,
|
||||
tool_ids: List[str],
|
||||
) -> AsyncIterator[str]:
|
||||
"""Post a chat message to the chatbot and stream progress updates (SSE).
|
||||
|
||||
|
|
@ -384,32 +384,20 @@ async def post_message_stream(
|
|||
thread_id: The unique identifier for the chat thread.
|
||||
message: The content of the chat message.
|
||||
user: The current user.
|
||||
tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools.
|
||||
|
||||
Yields:
|
||||
Server-Sent Events formatted strings containing status updates and final response.
|
||||
"""
|
||||
logger.info(f"User {user.id} streaming message to thread {thread_id}")
|
||||
logger.info(
|
||||
f"User {user.id} streaming message to thread {thread_id} with {len(tool_ids)} tools"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get user permissions
|
||||
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
||||
if not tool_ids:
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "User does not have permission to use any chatbot tools",
|
||||
}
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
return
|
||||
|
||||
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||
|
||||
# Get tools from registry
|
||||
# Get tools from registry (empty list if no tools)
|
||||
registry = get_registry()
|
||||
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||
|
||||
|
|
@ -939,3 +927,96 @@ async def get_tools_for_user(*, user_id: str, session: AsyncSession) -> List[dic
|
|||
|
||||
logger.info(f"Retrieved {len(tool_list)} tools for user {user_id}")
|
||||
return tool_list
|
||||
|
||||
|
||||
async def validate_and_get_tools_for_request(
|
||||
*,
|
||||
user_id: str,
|
||||
requested_tool_ids: Optional[List[str]],
|
||||
session: AsyncSession,
|
||||
) -> List[str]:
|
||||
"""Validate and get tool IDs for a chat request.
|
||||
|
||||
This function validates that the user has access to the requested tools.
|
||||
If no tools are requested (None), it returns all tools the user has access to.
|
||||
If an empty list is provided, it returns an empty list (no tools).
|
||||
|
||||
Args:
|
||||
user_id: The user ID making the request.
|
||||
requested_tool_ids: Optional list of tool UUIDs (id field) requested by the user.
|
||||
- None: Use all tools the user has access to
|
||||
- []: Use no tools at all
|
||||
- ["uuid1", "uuid2"]: Use only the specified tools
|
||||
session: The database session for querying.
|
||||
|
||||
Returns:
|
||||
List of validated tool IDs (tool_id field, not UUID) that the user can use.
|
||||
|
||||
Raises:
|
||||
PermissionError: If the user requests tools they don't have access to.
|
||||
ValueError: If the user has no tools available when trying to use all tools.
|
||||
"""
|
||||
from modules.features.chatBot.database import Tool, UserToolMapping
|
||||
import uuid
|
||||
|
||||
logger.info(f"Validating tools for user {user_id}")
|
||||
|
||||
# If empty list is explicitly provided, return empty list (no tools)
|
||||
if requested_tool_ids is not None and len(requested_tool_ids) == 0:
|
||||
logger.info(
|
||||
f"Empty tool list requested, chatbot will run without tools for user {user_id}"
|
||||
)
|
||||
return []
|
||||
|
||||
# Get all tools the user has access to
|
||||
stmt = (
|
||||
select(Tool)
|
||||
.join(UserToolMapping, Tool.id == UserToolMapping.tool_id)
|
||||
.where(
|
||||
UserToolMapping.user_id == user_id,
|
||||
UserToolMapping.is_active == True,
|
||||
Tool.is_active == True,
|
||||
)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
user_tools = result.scalars().all()
|
||||
|
||||
# Create mappings for both UUID and tool_id
|
||||
user_tool_ids_by_uuid = {str(tool.id): tool.tool_id for tool in user_tools}
|
||||
user_tool_ids = set(user_tool_ids_by_uuid.values())
|
||||
|
||||
if not user_tool_ids:
|
||||
logger.warning(f"User {user_id} has no tools available")
|
||||
raise ValueError("User does not have access to any chatbot tools")
|
||||
|
||||
# If no specific tools requested (None), return all user's tools
|
||||
if requested_tool_ids is None:
|
||||
logger.info(
|
||||
f"No specific tools requested, returning all {len(user_tool_ids)} tools for user {user_id}"
|
||||
)
|
||||
return list(user_tool_ids)
|
||||
|
||||
# Convert requested UUIDs to tool_ids and validate access
|
||||
requested_tool_ids_result = []
|
||||
unauthorized_uuids = []
|
||||
|
||||
for requested_uuid in requested_tool_ids:
|
||||
if requested_uuid in user_tool_ids_by_uuid:
|
||||
# User has access to this tool
|
||||
requested_tool_ids_result.append(user_tool_ids_by_uuid[requested_uuid])
|
||||
else:
|
||||
# User doesn't have access to this tool
|
||||
unauthorized_uuids.append(requested_uuid)
|
||||
|
||||
if unauthorized_uuids:
|
||||
logger.warning(
|
||||
f"User {user_id} requested unauthorized tool UUIDs: {unauthorized_uuids}"
|
||||
)
|
||||
raise PermissionError(
|
||||
f"You do not have access to the following tools: {', '.join(unauthorized_uuids)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Validated {len(requested_tool_ids_result)} requested tools for user {user_id}"
|
||||
)
|
||||
return requested_tool_ids_result
|
||||
|
|
|
|||
|
|
@ -62,6 +62,13 @@ async def post_chat_message_stream(
|
|||
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
||||
"""
|
||||
try:
|
||||
# Validate and get tools for the request
|
||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||
user_id=currentUser.id,
|
||||
requested_tool_ids=message_request.tools,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Get or create thread using helper function
|
||||
thread_id = await get_or_create_thread_for_user(
|
||||
thread_id=message_request.thread_id,
|
||||
|
|
@ -80,6 +87,7 @@ async def post_chat_message_stream(
|
|||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
tool_ids=tool_ids,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -88,6 +96,18 @@ async def post_chat_message_stream(
|
|||
},
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
|
|
@ -114,6 +134,13 @@ async def post_chat_message(
|
|||
For streaming updates, use the /message/stream endpoint instead.
|
||||
"""
|
||||
try:
|
||||
# Validate and get tools for the request
|
||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||
user_id=currentUser.id,
|
||||
requested_tool_ids=message_request.tools,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Get or create thread using helper function
|
||||
thread_id = await get_or_create_thread_for_user(
|
||||
thread_id=message_request.thread_id,
|
||||
|
|
@ -129,16 +156,23 @@ async def post_chat_message(
|
|||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
tool_ids=tool_ids,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except ValueError as e:
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
|
|
|
|||
Loading…
Reference in a new issue