feat: allow users to specify tools when posting messages
This commit is contained in:
parent
2b5d7506d0
commit
ba1daa2d73
3 changed files with 144 additions and 25 deletions
|
|
@ -20,6 +20,10 @@ class ChatMessageRequest(BaseModel, ModelMixin):
|
||||||
None, description="Thread ID (creates new thread if not provided)"
|
None, description="Thread ID (creates new thread if not provided)"
|
||||||
)
|
)
|
||||||
message: str = Field(..., description="User message content")
|
message: str = Field(..., description="User message content")
|
||||||
|
tools: Optional[List[str]] = Field(
|
||||||
|
None,
|
||||||
|
description="List of tool IDs to use. If not provided, all user's tools will be used",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageResponse(BaseModel, ModelMixin):
|
class ChatMessageResponse(BaseModel, ModelMixin):
|
||||||
|
|
|
||||||
|
|
@ -301,6 +301,7 @@ async def post_message(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
user: User,
|
user: User,
|
||||||
|
tool_ids: List[str],
|
||||||
) -> ChatMessageResponse:
|
) -> ChatMessageResponse:
|
||||||
"""Post a chat message to the chatbot and return the response.
|
"""Post a chat message to the chatbot and return the response.
|
||||||
|
|
||||||
|
|
@ -308,21 +309,19 @@ async def post_message(
|
||||||
thread_id: The unique identifier for the chat thread.
|
thread_id: The unique identifier for the chat thread.
|
||||||
message: The content of the chat message.
|
message: The content of the chat message.
|
||||||
user: The current user.
|
user: The current user.
|
||||||
|
tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The response containing the full chat message history and thread ID.
|
The response containing the full chat message history and thread ID.
|
||||||
"""
|
"""
|
||||||
logger.info(f"User {user.id} posted message to thread {thread_id}")
|
logger.info(
|
||||||
|
f"User {user.id} posted message to thread {thread_id} with {len(tool_ids)} tools"
|
||||||
# Get user permissions
|
)
|
||||||
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
|
||||||
if not tool_ids:
|
|
||||||
raise ValueError("User does not have permission to use any chatbot tools")
|
|
||||||
|
|
||||||
model_name = permissions.get_chatbot_model(user_id=user.id)
|
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||||
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||||
|
|
||||||
# Get tools from registry
|
# Get tools from registry (empty list if no tools)
|
||||||
registry = get_registry()
|
registry = get_registry()
|
||||||
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||||
|
|
||||||
|
|
@ -377,6 +376,7 @@ async def post_message_stream(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
user: User,
|
user: User,
|
||||||
|
tool_ids: List[str],
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
"""Post a chat message to the chatbot and stream progress updates (SSE).
|
"""Post a chat message to the chatbot and stream progress updates (SSE).
|
||||||
|
|
||||||
|
|
@ -384,32 +384,20 @@ async def post_message_stream(
|
||||||
thread_id: The unique identifier for the chat thread.
|
thread_id: The unique identifier for the chat thread.
|
||||||
message: The content of the chat message.
|
message: The content of the chat message.
|
||||||
user: The current user.
|
user: The current user.
|
||||||
|
tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Server-Sent Events formatted strings containing status updates and final response.
|
Server-Sent Events formatted strings containing status updates and final response.
|
||||||
"""
|
"""
|
||||||
logger.info(f"User {user.id} streaming message to thread {thread_id}")
|
logger.info(
|
||||||
|
f"User {user.id} streaming message to thread {thread_id} with {len(tool_ids)} tools"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get user permissions
|
|
||||||
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
|
||||||
if not tool_ids:
|
|
||||||
yield (
|
|
||||||
"data: "
|
|
||||||
+ json.dumps(
|
|
||||||
{
|
|
||||||
"type": "error",
|
|
||||||
"message": "User does not have permission to use any chatbot tools",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n\n"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
model_name = permissions.get_chatbot_model(user_id=user.id)
|
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||||
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||||
|
|
||||||
# Get tools from registry
|
# Get tools from registry (empty list if no tools)
|
||||||
registry = get_registry()
|
registry = get_registry()
|
||||||
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||||
|
|
||||||
|
|
@ -939,3 +927,96 @@ async def get_tools_for_user(*, user_id: str, session: AsyncSession) -> List[dic
|
||||||
|
|
||||||
logger.info(f"Retrieved {len(tool_list)} tools for user {user_id}")
|
logger.info(f"Retrieved {len(tool_list)} tools for user {user_id}")
|
||||||
return tool_list
|
return tool_list
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_and_get_tools_for_request(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
requested_tool_ids: Optional[List[str]],
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Validate and get tool IDs for a chat request.
|
||||||
|
|
||||||
|
This function validates that the user has access to the requested tools.
|
||||||
|
If no tools are requested (None), it returns all tools the user has access to.
|
||||||
|
If an empty list is provided, it returns an empty list (no tools).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID making the request.
|
||||||
|
requested_tool_ids: Optional list of tool UUIDs (id field) requested by the user.
|
||||||
|
- None: Use all tools the user has access to
|
||||||
|
- []: Use no tools at all
|
||||||
|
- ["uuid1", "uuid2"]: Use only the specified tools
|
||||||
|
session: The database session for querying.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of validated tool IDs (tool_id field, not UUID) that the user can use.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PermissionError: If the user requests tools they don't have access to.
|
||||||
|
ValueError: If the user has no tools available when trying to use all tools.
|
||||||
|
"""
|
||||||
|
from modules.features.chatBot.database import Tool, UserToolMapping
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
logger.info(f"Validating tools for user {user_id}")
|
||||||
|
|
||||||
|
# If empty list is explicitly provided, return empty list (no tools)
|
||||||
|
if requested_tool_ids is not None and len(requested_tool_ids) == 0:
|
||||||
|
logger.info(
|
||||||
|
f"Empty tool list requested, chatbot will run without tools for user {user_id}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Get all tools the user has access to
|
||||||
|
stmt = (
|
||||||
|
select(Tool)
|
||||||
|
.join(UserToolMapping, Tool.id == UserToolMapping.tool_id)
|
||||||
|
.where(
|
||||||
|
UserToolMapping.user_id == user_id,
|
||||||
|
UserToolMapping.is_active == True,
|
||||||
|
Tool.is_active == True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
user_tools = result.scalars().all()
|
||||||
|
|
||||||
|
# Create mappings for both UUID and tool_id
|
||||||
|
user_tool_ids_by_uuid = {str(tool.id): tool.tool_id for tool in user_tools}
|
||||||
|
user_tool_ids = set(user_tool_ids_by_uuid.values())
|
||||||
|
|
||||||
|
if not user_tool_ids:
|
||||||
|
logger.warning(f"User {user_id} has no tools available")
|
||||||
|
raise ValueError("User does not have access to any chatbot tools")
|
||||||
|
|
||||||
|
# If no specific tools requested (None), return all user's tools
|
||||||
|
if requested_tool_ids is None:
|
||||||
|
logger.info(
|
||||||
|
f"No specific tools requested, returning all {len(user_tool_ids)} tools for user {user_id}"
|
||||||
|
)
|
||||||
|
return list(user_tool_ids)
|
||||||
|
|
||||||
|
# Convert requested UUIDs to tool_ids and validate access
|
||||||
|
requested_tool_ids_result = []
|
||||||
|
unauthorized_uuids = []
|
||||||
|
|
||||||
|
for requested_uuid in requested_tool_ids:
|
||||||
|
if requested_uuid in user_tool_ids_by_uuid:
|
||||||
|
# User has access to this tool
|
||||||
|
requested_tool_ids_result.append(user_tool_ids_by_uuid[requested_uuid])
|
||||||
|
else:
|
||||||
|
# User doesn't have access to this tool
|
||||||
|
unauthorized_uuids.append(requested_uuid)
|
||||||
|
|
||||||
|
if unauthorized_uuids:
|
||||||
|
logger.warning(
|
||||||
|
f"User {user_id} requested unauthorized tool UUIDs: {unauthorized_uuids}"
|
||||||
|
)
|
||||||
|
raise PermissionError(
|
||||||
|
f"You do not have access to the following tools: {', '.join(unauthorized_uuids)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Validated {len(requested_tool_ids_result)} requested tools for user {user_id}"
|
||||||
|
)
|
||||||
|
return requested_tool_ids_result
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,13 @@ async def post_chat_message_stream(
|
||||||
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Validate and get tools for the request
|
||||||
|
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||||
|
user_id=currentUser.id,
|
||||||
|
requested_tool_ids=message_request.tools,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
# Get or create thread using helper function
|
# Get or create thread using helper function
|
||||||
thread_id = await get_or_create_thread_for_user(
|
thread_id = await get_or_create_thread_for_user(
|
||||||
thread_id=message_request.thread_id,
|
thread_id=message_request.thread_id,
|
||||||
|
|
@ -80,6 +87,7 @@ async def post_chat_message_stream(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
message=message_request.message,
|
message=message_request.message,
|
||||||
user=currentUser,
|
user=currentUser,
|
||||||
|
tool_ids=tool_ids,
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
|
|
@ -88,6 +96,18 @@ async def post_chat_message_stream(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except PermissionError as e:
|
||||||
|
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=str(e) or "Permission denied",
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=str(e) or "Permission denied",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||||
|
|
@ -114,6 +134,13 @@ async def post_chat_message(
|
||||||
For streaming updates, use the /message/stream endpoint instead.
|
For streaming updates, use the /message/stream endpoint instead.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Validate and get tools for the request
|
||||||
|
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||||
|
user_id=currentUser.id,
|
||||||
|
requested_tool_ids=message_request.tools,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
# Get or create thread using helper function
|
# Get or create thread using helper function
|
||||||
thread_id = await get_or_create_thread_for_user(
|
thread_id = await get_or_create_thread_for_user(
|
||||||
thread_id=message_request.thread_id,
|
thread_id=message_request.thread_id,
|
||||||
|
|
@ -129,16 +156,23 @@ async def post_chat_message(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
message=message_request.message,
|
message=message_request.message,
|
||||||
user=currentUser,
|
user=currentUser,
|
||||||
|
tool_ids=tool_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except ValueError as e:
|
except PermissionError as e:
|
||||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=str(e) or "Permission denied",
|
detail=str(e) or "Permission denied",
|
||||||
)
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=str(e) or "Permission denied",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue