removed chatbot on langgraph platform
This commit is contained in:
parent
711b8bc50d
commit
f05d958213
20 changed files with 7 additions and 3922 deletions
2
app.py
2
app.py
|
|
@ -422,5 +422,3 @@ app.include_router(voiceGoogleRouter)
|
||||||
from modules.routes.routeSecurityAdmin import router as adminSecurityRouter
|
from modules.routes.routeSecurityAdmin import router as adminSecurityRouter
|
||||||
app.include_router(adminSecurityRouter)
|
app.include_router(adminSecurityRouter)
|
||||||
|
|
||||||
# from modules.routes.routeChatbot import router as chatbotRouter
|
|
||||||
# app.include_router(chatbotRouter)
|
|
||||||
|
|
|
||||||
|
|
@ -1,216 +0,0 @@
|
||||||
"""Chatbot API models for request/response handling."""
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
|
||||||
|
|
||||||
|
|
||||||
# Chatbot API Models
|
|
||||||
class MessageItem(BaseModel, ModelMixin):
|
|
||||||
"""Individual message in a thread"""
|
|
||||||
|
|
||||||
role: str = Field(..., description="Message role (user or assistant)")
|
|
||||||
content: str = Field(..., description="Message content")
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageRequest(BaseModel, ModelMixin):
|
|
||||||
"""Request model for posting a chat message"""
|
|
||||||
|
|
||||||
thread_id: Optional[str] = Field(
|
|
||||||
None, description="Thread ID (creates new thread if not provided)"
|
|
||||||
)
|
|
||||||
message: str = Field(..., description="User message content")
|
|
||||||
tools: Optional[List[str]] = Field(
|
|
||||||
None,
|
|
||||||
description="List of tool IDs to use. If not provided, all user's tools will be used",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageResponse(BaseModel, ModelMixin):
|
|
||||||
"""Response model for posting a chat message"""
|
|
||||||
|
|
||||||
thread_id: str = Field(..., description="Thread ID")
|
|
||||||
messages: List[MessageItem] = Field(..., description="All messages in thread")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadSummary(BaseModel, ModelMixin):
|
|
||||||
"""Summary of a chat thread for list view"""
|
|
||||||
|
|
||||||
thread_id: str = Field(..., description="Thread ID")
|
|
||||||
thread_name: str = Field(..., description="Thread name")
|
|
||||||
date_created: float = Field(..., description="Thread creation timestamp")
|
|
||||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadListResponse(BaseModel, ModelMixin):
|
|
||||||
"""Response model for listing all threads"""
|
|
||||||
|
|
||||||
threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadDetail(BaseModel, ModelMixin):
|
|
||||||
"""Detailed view of a single thread"""
|
|
||||||
|
|
||||||
thread_id: str = Field(..., description="Thread ID")
|
|
||||||
date_created: float = Field(..., description="Thread creation timestamp")
|
|
||||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
|
||||||
messages: List[MessageItem] = Field(
|
|
||||||
..., description="All messages in chronological order"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RenameThreadRequest(BaseModel, ModelMixin):
|
|
||||||
"""Request model for renaming a thread"""
|
|
||||||
|
|
||||||
new_name: str = Field(..., description="New name for the thread")
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteResponse(BaseModel, ModelMixin):
|
|
||||||
"""Response model for delete operations"""
|
|
||||||
|
|
||||||
message: str = Field(..., description="Confirmation message")
|
|
||||||
thread_id: str = Field(..., description="Deleted thread ID")
|
|
||||||
|
|
||||||
|
|
||||||
# Tool Management Models
|
|
||||||
class ToolInfo(BaseModel, ModelMixin):
|
|
||||||
"""Information about a chatbot tool"""
|
|
||||||
|
|
||||||
id: str = Field(..., description="Tool UUID")
|
|
||||||
tool_id: str = Field(
|
|
||||||
..., description="Tool identifier (e.g., 'shared.tavily_search')"
|
|
||||||
)
|
|
||||||
name: str = Field(..., description="Tool function name")
|
|
||||||
label: str = Field(..., description="Display label for the tool")
|
|
||||||
category: str = Field(..., description="Tool category (shared or customer)")
|
|
||||||
description: str = Field(..., description="Tool description")
|
|
||||||
is_active: bool = Field(..., description="Whether the tool is active")
|
|
||||||
date_created: float = Field(..., description="Creation timestamp")
|
|
||||||
date_updated: float = Field(..., description="Last update timestamp")
|
|
||||||
|
|
||||||
|
|
||||||
class ToolListResponse(BaseModel, ModelMixin):
|
|
||||||
"""Response model for listing all tools"""
|
|
||||||
|
|
||||||
tools: List[ToolInfo] = Field(..., description="List of available tools")
|
|
||||||
|
|
||||||
|
|
||||||
class GrantToolRequest(BaseModel, ModelMixin):
|
|
||||||
"""Request model for granting a tool to a user"""
|
|
||||||
|
|
||||||
user_id: str = Field(..., description="User ID to grant the tool to")
|
|
||||||
tool_id: str = Field(..., description="Tool UUID from tools table")
|
|
||||||
|
|
||||||
|
|
||||||
class GrantToolResponse(BaseModel, ModelMixin):
|
|
||||||
"""Response model after granting a tool"""
|
|
||||||
|
|
||||||
message: str = Field(..., description="Confirmation message")
|
|
||||||
user_id: str = Field(..., description="User ID")
|
|
||||||
tool_id: str = Field(..., description="Tool UUID")
|
|
||||||
|
|
||||||
|
|
||||||
class RevokeToolRequest(BaseModel, ModelMixin):
|
|
||||||
"""Request model for revoking a tool from a user"""
|
|
||||||
|
|
||||||
user_id: str = Field(..., description="User ID to revoke the tool from")
|
|
||||||
tool_id: str = Field(..., description="Tool UUID from tools table")
|
|
||||||
|
|
||||||
|
|
||||||
class RevokeToolResponse(BaseModel, ModelMixin):
|
|
||||||
"""Response model after revoking a tool"""
|
|
||||||
|
|
||||||
message: str = Field(..., description="Confirmation message")
|
|
||||||
user_id: str = Field(..., description="User ID")
|
|
||||||
tool_id: str = Field(..., description="Tool UUID")
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateToolRequest(BaseModel, ModelMixin):
|
|
||||||
"""Request model for updating a tool's label and description"""
|
|
||||||
|
|
||||||
label: Optional[str] = Field(None, description="New label for the tool")
|
|
||||||
description: Optional[str] = Field(None, description="New description for the tool")
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateToolResponse(BaseModel, ModelMixin):
|
|
||||||
"""Response model after updating a tool"""
|
|
||||||
|
|
||||||
message: str = Field(..., description="Confirmation message")
|
|
||||||
tool_id: str = Field(..., description="Tool UUID")
|
|
||||||
updated_fields: List[str] = Field(..., description="List of updated field names")
|
|
||||||
|
|
||||||
|
|
||||||
# Register model labels for internationalization
|
|
||||||
register_model_labels(
|
|
||||||
"MessageItem",
|
|
||||||
{"en": "Message Item", "fr": "Élément de message"},
|
|
||||||
{
|
|
||||||
"role": {"en": "Role", "fr": "Rôle"},
|
|
||||||
"content": {"en": "Content", "fr": "Contenu"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
register_model_labels(
|
|
||||||
"ChatMessageRequest",
|
|
||||||
{"en": "Chat Message Request", "fr": "Demande de message de chat"},
|
|
||||||
{
|
|
||||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
|
||||||
"message": {"en": "Message", "fr": "Message"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
register_model_labels(
|
|
||||||
"ChatMessageResponse",
|
|
||||||
{"en": "Chat Message Response", "fr": "Réponse du message de chat"},
|
|
||||||
{
|
|
||||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
|
||||||
"messages": {"en": "Messages", "fr": "Messages"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
register_model_labels(
|
|
||||||
"ThreadSummary",
|
|
||||||
{"en": "Thread Summary", "fr": "Résumé du fil"},
|
|
||||||
{
|
|
||||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
|
||||||
"thread_name": {"en": "Thread Name", "fr": "Nom du fil"},
|
|
||||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
|
||||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
register_model_labels(
|
|
||||||
"ThreadListResponse",
|
|
||||||
{"en": "Thread List Response", "fr": "Réponse de liste de fils"},
|
|
||||||
{
|
|
||||||
"threads": {"en": "Threads", "fr": "Fils"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
register_model_labels(
|
|
||||||
"ThreadDetail",
|
|
||||||
{"en": "Thread Detail", "fr": "Détail du fil"},
|
|
||||||
{
|
|
||||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
|
||||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
|
||||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
|
||||||
"messages": {"en": "Messages", "fr": "Messages"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
register_model_labels(
|
|
||||||
"RenameThreadRequest",
|
|
||||||
{"en": "Rename Thread Request", "fr": "Demande de renommage de fil"},
|
|
||||||
{
|
|
||||||
"new_name": {"en": "New Name", "fr": "Nouveau nom"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
register_model_labels(
|
|
||||||
"DeleteResponse",
|
|
||||||
{"en": "Delete Response", "fr": "Réponse de suppression"},
|
|
||||||
{
|
|
||||||
"message": {"en": "Message", "fr": "Message"},
|
|
||||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
"""Contains all tools available for the chatbot to use."""
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
"""Tools that are shared between multiple customers go here."""
|
|
||||||
|
|
@ -1,208 +0,0 @@
|
||||||
"""Althaus Database Query Tool for LangGraph.
|
|
||||||
|
|
||||||
This tool provides database query capabilities for the Althaus database
|
|
||||||
via an external REST API. Only SELECT queries are allowed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
import re
|
|
||||||
from typing import Annotated
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def _mock_api_call(*, sql_query: str) -> dict:
|
|
||||||
"""Mock the external REST API call to Althaus database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sql_query: The SQL SELECT query to execute
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary containing the query results with columns and rows
|
|
||||||
"""
|
|
||||||
# Simulate network delay
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
# Mock response data based on common query patterns
|
|
||||||
if "users" in sql_query.lower():
|
|
||||||
return {
|
|
||||||
"columns": ["id", "username", "email", "created_at"],
|
|
||||||
"rows": [
|
|
||||||
[1, "john_doe", "john@example.com", "2024-01-15"],
|
|
||||||
[2, "jane_smith", "jane@example.com", "2024-02-20"],
|
|
||||||
[3, "bob_wilson", "bob@example.com", "2024-03-10"],
|
|
||||||
],
|
|
||||||
"row_count": 3,
|
|
||||||
}
|
|
||||||
elif "products" in sql_query.lower():
|
|
||||||
return {
|
|
||||||
"columns": ["product_id", "name", "price", "stock"],
|
|
||||||
"rows": [
|
|
||||||
[101, "Widget A", 29.99, 150],
|
|
||||||
[102, "Widget B", 39.99, 75],
|
|
||||||
[103, "Widget C", 19.99, 200],
|
|
||||||
],
|
|
||||||
"row_count": 3,
|
|
||||||
}
|
|
||||||
elif "orders" in sql_query.lower():
|
|
||||||
return {
|
|
||||||
"columns": ["order_id", "customer_id", "total", "status"],
|
|
||||||
"rows": [
|
|
||||||
[5001, 1, 129.99, "completed"],
|
|
||||||
[5002, 2, 89.50, "pending"],
|
|
||||||
[5003, 1, 199.99, "shipped"],
|
|
||||||
],
|
|
||||||
"row_count": 3,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Generic response for other queries
|
|
||||||
return {
|
|
||||||
"columns": ["id", "value", "description"],
|
|
||||||
"rows": [
|
|
||||||
[1, "Sample 1", "First sample entry"],
|
|
||||||
[2, "Sample 2", "Second sample entry"],
|
|
||||||
],
|
|
||||||
"row_count": 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
|
|
||||||
"""Validate that the query is a SELECT statement only.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sql_query: The SQL query to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple of (is_valid, error_message)
|
|
||||||
"""
|
|
||||||
# Remove leading/trailing whitespace and convert to lowercase for checking
|
|
||||||
normalized_query = sql_query.strip().lower()
|
|
||||||
|
|
||||||
# Check if query starts with SELECT
|
|
||||||
if not normalized_query.startswith("select"):
|
|
||||||
return False, "Query must be a SELECT statement"
|
|
||||||
|
|
||||||
# Check for dangerous keywords that should not be in a SELECT query
|
|
||||||
dangerous_keywords = [
|
|
||||||
"insert",
|
|
||||||
"update",
|
|
||||||
"delete",
|
|
||||||
"drop",
|
|
||||||
"create",
|
|
||||||
"alter",
|
|
||||||
"truncate",
|
|
||||||
"grant",
|
|
||||||
"revoke",
|
|
||||||
"exec",
|
|
||||||
"execute",
|
|
||||||
]
|
|
||||||
|
|
||||||
for keyword in dangerous_keywords:
|
|
||||||
# Use word boundary to match whole words only
|
|
||||||
if re.search(rf"\b{keyword}\b", normalized_query):
|
|
||||||
return False, f"Query contains forbidden keyword: {keyword.upper()}"
|
|
||||||
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
|
|
||||||
def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
|
|
||||||
"""Format query results into a readable string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
columns: List of column names
|
|
||||||
rows: List of row data
|
|
||||||
row_count: Total number of rows
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string representation of the results
|
|
||||||
"""
|
|
||||||
if row_count == 0:
|
|
||||||
return "Query executed successfully but returned no results."
|
|
||||||
|
|
||||||
# Calculate column widths
|
|
||||||
col_widths = [len(str(col)) for col in columns]
|
|
||||||
for row in rows:
|
|
||||||
for i, cell in enumerate(row):
|
|
||||||
col_widths[i] = max(col_widths[i], len(str(cell)))
|
|
||||||
|
|
||||||
# Build header
|
|
||||||
header_parts = []
|
|
||||||
for col, width in zip(columns, col_widths):
|
|
||||||
header_parts.append(str(col).ljust(width))
|
|
||||||
header = " | ".join(header_parts)
|
|
||||||
separator = "-" * len(header)
|
|
||||||
|
|
||||||
# Build rows
|
|
||||||
row_lines = []
|
|
||||||
for row in rows:
|
|
||||||
row_parts = []
|
|
||||||
for cell, width in zip(row, col_widths):
|
|
||||||
row_parts.append(str(cell).ljust(width))
|
|
||||||
row_lines.append(" | ".join(row_parts))
|
|
||||||
|
|
||||||
# Combine all parts
|
|
||||||
result_parts = [
|
|
||||||
f"Query returned {row_count} row(s):\n",
|
|
||||||
header,
|
|
||||||
separator,
|
|
||||||
"\n".join(row_lines),
|
|
||||||
]
|
|
||||||
|
|
||||||
return "\n".join(result_parts)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def query_althaus_database(
|
|
||||||
sql_query: Annotated[
|
|
||||||
str, "The SQL SELECT query to execute against the Althaus database"
|
|
||||||
],
|
|
||||||
) -> str:
|
|
||||||
"""Execute a SELECT query against the Althaus database via REST API.
|
|
||||||
|
|
||||||
Use this tool to query data from the Althaus database. Only SELECT statements
|
|
||||||
are allowed for security reasons. The query will be forwarded to an external
|
|
||||||
REST API and the results will be returned in a formatted table.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A formatted string containing the query results with columns and rows
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Validate the query
|
|
||||||
is_valid, error_msg = _validate_select_query(sql_query=sql_query)
|
|
||||||
if not is_valid:
|
|
||||||
logger.warning(f"Invalid query attempt: {sql_query[:100]}...")
|
|
||||||
return f"Error: {error_msg}"
|
|
||||||
|
|
||||||
logger.info(f"Executing Althaus database query: {sql_query[:100]}...")
|
|
||||||
|
|
||||||
# Mock the external REST API call
|
|
||||||
# In production, this would be replaced with actual REST API call:
|
|
||||||
# response = await httpx.AsyncClient().post(
|
|
||||||
# "https://api.althaus.example.com/query",
|
|
||||||
# json={"query": sql_query},
|
|
||||||
# headers={"Authorization": f"Bearer {api_key}"}
|
|
||||||
# )
|
|
||||||
# result = response.json()
|
|
||||||
|
|
||||||
result = await _mock_api_call(sql_query=sql_query)
|
|
||||||
|
|
||||||
# Format and return results
|
|
||||||
formatted_output = _format_results(
|
|
||||||
columns=result["columns"],
|
|
||||||
rows=result["rows"],
|
|
||||||
row_count=result["row_count"],
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Query completed successfully, returned {result['row_count']} row(s)"
|
|
||||||
)
|
|
||||||
return formatted_output
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in query_althaus_database tool: {str(e)}")
|
|
||||||
return f"Error executing query: {str(e)}"
|
|
||||||
|
|
@ -1,362 +0,0 @@
|
||||||
"""Power BI Query Tool for LangGraph.
|
|
||||||
|
|
||||||
This tool provides DAX query capabilities for Power BI datasets
|
|
||||||
via the Power BI REST API. Only read-only queries are allowed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import functools
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import anyio
|
|
||||||
import httpx
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from msal import ConfidentialClientApplication, SerializableTokenCache
|
|
||||||
|
|
||||||
from modules.shared.configuration import APP_CONFIG
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# Configuration constants - encapsulated in this file
|
|
||||||
POWERBI_DATASET_ID = APP_CONFIG.get("VALUEON_POWERBI_DATASET_ID", "")
|
|
||||||
POWERBI_CLIENT_ID = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_ID", "")
|
|
||||||
POWERBI_CLIENT_SECRET = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_SECRET", "")
|
|
||||||
POWERBI_TENANT_ID = APP_CONFIG.get("VALUEON_POWERBI_TENANT_ID", "")
|
|
||||||
POWERBI_BASE_URL = "https://api.powerbi.com/v1.0/myorg"
|
|
||||||
POWERBI_AUTHORITY_BASE = "https://login.microsoftonline.com"
|
|
||||||
POWERBI_SCOPE = ["https://analysis.windows.net/powerbi/api/.default"]
|
|
||||||
|
|
||||||
# Limit results to prevent excessive context usage
|
|
||||||
MAX_ROWS_LIMIT = 100
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_environment() -> tuple[bool, str]:
|
|
||||||
"""Validate that all required environment variables are set.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple of (is_valid, error_message)
|
|
||||||
"""
|
|
||||||
missing = []
|
|
||||||
if not POWERBI_DATASET_ID:
|
|
||||||
missing.append("POWERBI_DATASET_ID")
|
|
||||||
if not POWERBI_CLIENT_ID:
|
|
||||||
missing.append("POWERBI_CLIENT_ID")
|
|
||||||
if not POWERBI_CLIENT_SECRET:
|
|
||||||
missing.append("POWERBI_CLIENT_SECRET")
|
|
||||||
if not POWERBI_TENANT_ID:
|
|
||||||
missing.append("POWERBI_TENANT_ID")
|
|
||||||
|
|
||||||
if missing:
|
|
||||||
return False, f"Missing required environment variables: {', '.join(missing)}"
|
|
||||||
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_dax_query(*, dax_query: str) -> tuple[bool, str]:
|
|
||||||
"""Validate that the query is a valid DAX query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dax_query: The DAX query to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple of (is_valid, error_message)
|
|
||||||
"""
|
|
||||||
# Remove leading/trailing whitespace
|
|
||||||
normalized_query = dax_query.strip()
|
|
||||||
|
|
||||||
if not normalized_query:
|
|
||||||
return False, "Query cannot be empty"
|
|
||||||
|
|
||||||
# DAX queries typically start with EVALUATE, DEFINE, or are table expressions
|
|
||||||
# We'll be lenient and just check it's not trying to do something dangerous
|
|
||||||
# DAX is read-only by nature, but we validate structure
|
|
||||||
|
|
||||||
# Check for minimum length
|
|
||||||
if len(normalized_query) < 5:
|
|
||||||
return False, "Query is too short to be valid"
|
|
||||||
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
|
|
||||||
def _get_access_token_sync(
|
|
||||||
*,
|
|
||||||
tenant_id: str,
|
|
||||||
client_id: str,
|
|
||||||
client_secret: str,
|
|
||||||
authority_base: str = POWERBI_AUTHORITY_BASE,
|
|
||||||
cache: SerializableTokenCache | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Get Power BI access token using MSAL (synchronous).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id: Azure AD tenant ID
|
|
||||||
client_id: Application client ID
|
|
||||||
client_secret: Application client secret
|
|
||||||
authority_base: Azure AD authority base URL
|
|
||||||
cache: Optional token cache for reuse
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Access token string
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If token acquisition fails
|
|
||||||
"""
|
|
||||||
authority = f"{authority_base}/{tenant_id}"
|
|
||||||
|
|
||||||
app = ConfidentialClientApplication(
|
|
||||||
client_id=client_id,
|
|
||||||
authority=authority,
|
|
||||||
client_credential=client_secret,
|
|
||||||
token_cache=cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try cache first; fall back to client credentials
|
|
||||||
result = app.acquire_token_silent(
|
|
||||||
POWERBI_SCOPE, account=None
|
|
||||||
) or app.acquire_token_for_client(scopes=POWERBI_SCOPE)
|
|
||||||
|
|
||||||
if "access_token" not in result:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"MSAL token error: {result.get('error')} - {result.get('error_description')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return result["access_token"]
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_access_token_async(
|
|
||||||
*,
|
|
||||||
tenant_id: str,
|
|
||||||
client_id: str,
|
|
||||||
client_secret: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
"""Get Power BI access token using MSAL (asynchronous).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id: Azure AD tenant ID
|
|
||||||
client_id: Application client ID
|
|
||||||
client_secret: Application client secret
|
|
||||||
**kwargs: Additional arguments for _get_access_token_sync
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Access token string
|
|
||||||
"""
|
|
||||||
# Create a partial function with arguments pre-filled
|
|
||||||
func = functools.partial(
|
|
||||||
_get_access_token_sync,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
client_id=client_id,
|
|
||||||
client_secret=client_secret,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
# Offload the blocking MSAL HTTP call to a worker thread
|
|
||||||
return await anyio.to_thread.run_sync(func)
|
|
||||||
|
|
||||||
|
|
||||||
async def _execute_dax_query(
|
|
||||||
*, dax_query: str, dataset_id: str, access_token: str
|
|
||||||
) -> dict:
|
|
||||||
"""Execute a DAX query against Power BI dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dax_query: The DAX query to execute
|
|
||||||
dataset_id: Power BI dataset ID
|
|
||||||
access_token: Access token for authentication
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing query results
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If query execution fails
|
|
||||||
"""
|
|
||||||
url = f"{POWERBI_BASE_URL}/datasets/{dataset_id}/executeQueries"
|
|
||||||
|
|
||||||
body = {
|
|
||||||
"queries": [{"query": dax_query}],
|
|
||||||
"serializerSettings": {"includeNulls": True},
|
|
||||||
}
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {access_token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
|
||||||
resp = await client.post(url, headers=headers, json=body)
|
|
||||||
|
|
||||||
if resp.status_code != 200:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Power BI executeQueries failed: {resp.status_code} - {resp.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = resp.json()
|
|
||||||
|
|
||||||
try:
|
|
||||||
rows = payload["results"][0]["tables"][0]["rows"]
|
|
||||||
except (KeyError, IndexError) as e:
|
|
||||||
raise RuntimeError("Unexpected executeQueries response structure") from e
|
|
||||||
|
|
||||||
# Extract column names from the first row if available
|
|
||||||
if rows:
|
|
||||||
columns = list(rows[0].keys())
|
|
||||||
else:
|
|
||||||
columns = []
|
|
||||||
|
|
||||||
return {"columns": columns, "rows": rows}
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_table_qualifier(*, column_name: str) -> str:
|
|
||||||
"""Strip table qualifier from column name.
|
|
||||||
|
|
||||||
Power BI often returns columns as 'Table[Column]'. This strips to 'Column'.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
column_name: The column name to process
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Processed column name
|
|
||||||
"""
|
|
||||||
if "[" in column_name and column_name.endswith("]"):
|
|
||||||
return column_name.split("[", 1)[1][:-1]
|
|
||||||
return column_name
|
|
||||||
|
|
||||||
|
|
||||||
def _format_results(*, columns: list[str], rows: list[dict], max_rows: int) -> str:
|
|
||||||
"""Format query results into a readable string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
columns: List of column names
|
|
||||||
rows: List of row data (as dictionaries)
|
|
||||||
max_rows: Maximum number of rows to display
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string representation of the results
|
|
||||||
"""
|
|
||||||
total_rows = len(rows)
|
|
||||||
|
|
||||||
if total_rows == 0:
|
|
||||||
return "Query executed successfully but returned no results."
|
|
||||||
|
|
||||||
# Strip table qualifiers from column names
|
|
||||||
clean_columns = [_strip_table_qualifier(column_name=col) for col in columns]
|
|
||||||
|
|
||||||
# Limit rows to max_rows
|
|
||||||
display_rows = rows[:max_rows]
|
|
||||||
truncated = total_rows > max_rows
|
|
||||||
|
|
||||||
# Calculate column widths
|
|
||||||
col_widths = [len(str(col)) for col in clean_columns]
|
|
||||||
for row in display_rows:
|
|
||||||
for i, col in enumerate(columns):
|
|
||||||
value = row.get(col, "")
|
|
||||||
col_widths[i] = max(col_widths[i], len(str(value)))
|
|
||||||
|
|
||||||
# Build header
|
|
||||||
header_parts = []
|
|
||||||
for col, width in zip(clean_columns, col_widths):
|
|
||||||
header_parts.append(str(col).ljust(width))
|
|
||||||
header = " | ".join(header_parts)
|
|
||||||
separator = "-" * len(header)
|
|
||||||
|
|
||||||
# Build rows
|
|
||||||
row_lines = []
|
|
||||||
for row in display_rows:
|
|
||||||
row_parts = []
|
|
||||||
for col, width in zip(columns, col_widths):
|
|
||||||
value = row.get(col, "")
|
|
||||||
row_parts.append(str(value).ljust(width))
|
|
||||||
row_lines.append(" | ".join(row_parts))
|
|
||||||
|
|
||||||
# Combine all parts
|
|
||||||
result_parts = [
|
|
||||||
f"Query returned {total_rows} row(s):",
|
|
||||||
]
|
|
||||||
|
|
||||||
if truncated:
|
|
||||||
result_parts.append(
|
|
||||||
f"(Results limited to {max_rows} rows for context efficiency)\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
result_parts.append("")
|
|
||||||
|
|
||||||
result_parts.extend([header, separator, "\n".join(row_lines)])
|
|
||||||
|
|
||||||
return "\n".join(result_parts)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def query_powerbi_data(
|
|
||||||
dax_query: Annotated[str, "The DAX query to execute against the Power BI dataset"],
|
|
||||||
) -> str:
|
|
||||||
"""Execute a DAX query against the Power BI dataset to access warehouse inventory data.
|
|
||||||
|
|
||||||
This tool provides access to a Power BI table called 'data_full' which contains
|
|
||||||
articles available in the warehouse of the user. Use DAX (Data Analysis Expressions)
|
|
||||||
queries to retrieve and analyze this inventory data.
|
|
||||||
|
|
||||||
Available table:
|
|
||||||
- 'data_full': Contains warehouse inventory articles and their details
|
|
||||||
|
|
||||||
Common query patterns:
|
|
||||||
- View all data: EVALUATE 'data_full'
|
|
||||||
- With filter: EVALUATE FILTER('data_full', [Column] = "Value")
|
|
||||||
- Top N rows: EVALUATE TOPN(10, 'data_full', [Column], DESC)
|
|
||||||
- Calculated: EVALUATE SUMMARIZE('data_full', [Column1], "Total", SUM([Column2]))
|
|
||||||
|
|
||||||
Results are limited to 100 rows maximum for efficiency.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dax_query: The DAX query to execute (e.g., "EVALUATE 'data_full'")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A formatted string containing the query results with columns and rows
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Validate environment configuration
|
|
||||||
is_valid_env, error_msg = _validate_environment()
|
|
||||||
if not is_valid_env:
|
|
||||||
logger.error(f"Environment validation failed: {error_msg}")
|
|
||||||
return f"Configuration Error: {error_msg}"
|
|
||||||
|
|
||||||
# Validate the query
|
|
||||||
is_valid_query, error_msg = _validate_dax_query(dax_query=dax_query)
|
|
||||||
if not is_valid_query:
|
|
||||||
logger.warning(f"Invalid query attempt: {dax_query[:100]}...")
|
|
||||||
return f"Query Validation Error: {error_msg}"
|
|
||||||
|
|
||||||
logger.info(f"Executing Power BI query: {dax_query[:100]}...")
|
|
||||||
|
|
||||||
# Get access token
|
|
||||||
access_token = await _get_access_token_async(
|
|
||||||
tenant_id=POWERBI_TENANT_ID,
|
|
||||||
client_id=POWERBI_CLIENT_ID,
|
|
||||||
client_secret=POWERBI_CLIENT_SECRET,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute the query
|
|
||||||
result = await _execute_dax_query(
|
|
||||||
dax_query=dax_query,
|
|
||||||
dataset_id=POWERBI_DATASET_ID,
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format and return results
|
|
||||||
formatted_output = _format_results(
|
|
||||||
columns=result["columns"], rows=result["rows"], max_rows=MAX_ROWS_LIMIT
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Query completed successfully, returned {len(result['rows'])} row(s)"
|
|
||||||
)
|
|
||||||
return formatted_output
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
logger.error(f"Runtime error in query_powerbi_data tool: {str(e)}")
|
|
||||||
return f"Error executing query: {str(e)}"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error in query_powerbi_data tool: {str(e)}")
|
|
||||||
return f"Unexpected error: {str(e)}"
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
"""Shared tools available across all chatbot implementations."""
|
|
||||||
|
|
||||||
from modules.features.chatBot.chatbotTools.sharedTools.toolTavilySearch import (
|
|
||||||
tavily_search,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ["tavily_search"]
|
|
||||||
|
|
@ -1,24 +0,0 @@
|
||||||
"""Tool for sending streaming status updates to users."""
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def send_streaming_message(message: str) -> str:
|
|
||||||
"""Send a streaming message to the user to provide updates during processing.
|
|
||||||
|
|
||||||
Use this tool to send short status updates to the user while you are working
|
|
||||||
on their request. This helps keep the user informed about what you are doing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: A short message describing what you are currently doing.
|
|
||||||
Examples: "Searching database for relevant information..."
|
|
||||||
"Analyzing search results..."
|
|
||||||
"Processing your request..."
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A confirmation that the message was sent.
|
|
||||||
"""
|
|
||||||
# This tool doesn't actually do anything - it's just for the AI to signal
|
|
||||||
# what it's doing to the frontend via the tool call mechanism
|
|
||||||
return f"Status update sent: {message}"
|
|
||||||
|
|
@ -1,55 +0,0 @@
|
||||||
"""Tavily Search Tool for LangGraph.
|
|
||||||
|
|
||||||
This tool provides web search capabilities using the Tavily API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Annotated
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from modules.connectors.connectorAiTavily import ConnectorWeb
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def tavily_search(
|
|
||||||
query: Annotated[str, "The search query to look up on the web"],
|
|
||||||
) -> str:
|
|
||||||
"""Search the web using Tavily API.
|
|
||||||
|
|
||||||
Use this tool to search for current information, news, or any web content.
|
|
||||||
The tool returns relevant search results including titles and URLs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: The search query string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A formatted string containing search results with titles and URLs
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Create connector instance
|
|
||||||
connector = await ConnectorWeb.create()
|
|
||||||
|
|
||||||
# Perform search with default parameters
|
|
||||||
results = await connector._search(
|
|
||||||
query=query,
|
|
||||||
max_results=5,
|
|
||||||
search_depth="basic",
|
|
||||||
include_answer=True,
|
|
||||||
include_raw_content=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format results
|
|
||||||
if not results:
|
|
||||||
return f"No results found for query: {query}"
|
|
||||||
|
|
||||||
formatted_results = [f"Search results for '{query}':\n"]
|
|
||||||
for i, result in enumerate(results, 1):
|
|
||||||
formatted_results.append(f"{i}. {result.title}")
|
|
||||||
formatted_results.append(f" URL: {result.url}\n")
|
|
||||||
|
|
||||||
return "\n".join(formatted_results)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in tavily_search tool: {str(e)}")
|
|
||||||
return f"Error performing search: {str(e)}"
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
"""Domain logic for chatbot functionality."""
|
|
||||||
|
|
@ -1,301 +0,0 @@
|
||||||
"""Chatbot domain logic with LangGraph integration."""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Annotated, AsyncIterator, Any
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from langchain_core.messages import (
|
|
||||||
BaseMessage,
|
|
||||||
HumanMessage,
|
|
||||||
SystemMessage,
|
|
||||||
trim_messages,
|
|
||||||
)
|
|
||||||
from langgraph.graph.message import add_messages
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
|
||||||
from langgraph.prebuilt import ToolNode
|
|
||||||
from langchain_anthropic import ChatAnthropic
|
|
||||||
|
|
||||||
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
|
|
||||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
|
||||||
from modules.shared.configuration import APP_CONFIG
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatState(BaseModel):
|
|
||||||
"""Represents the state of a chat session."""
|
|
||||||
|
|
||||||
messages: Annotated[list[BaseMessage], add_messages]
|
|
||||||
|
|
||||||
|
|
||||||
def get_langchain_model(*, model_name: str) -> ChatAnthropic:
|
|
||||||
"""Map permission model names to LangChain ChatAnthropic models.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: The model name from permissions (e.g., "claude_4_5")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured ChatAnthropic instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the model name is not supported
|
|
||||||
"""
|
|
||||||
# Model name mapping
|
|
||||||
model_mapping = {
|
|
||||||
"claude_4_5": "claude-sonnet-4-5",
|
|
||||||
# Add more mappings as needed
|
|
||||||
}
|
|
||||||
|
|
||||||
anthropic_model = model_mapping.get(model_name)
|
|
||||||
if not anthropic_model:
|
|
||||||
logger.warning(
|
|
||||||
f"Unknown model name '{model_name}', defaulting to claude-4-5-sonnet"
|
|
||||||
)
|
|
||||||
anthropic_model = "claude-4-5-sonnet"
|
|
||||||
|
|
||||||
return ChatAnthropic(
|
|
||||||
model=anthropic_model,
|
|
||||||
api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
|
|
||||||
temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
|
|
||||||
max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Chatbot:
|
|
||||||
"""Represents a chatbot with LangGraph integration."""
|
|
||||||
|
|
||||||
model: Any
|
|
||||||
memory: Any
|
|
||||||
app: Any = None
|
|
||||||
system_prompt: str = "You are a helpful assistant."
|
|
||||||
context_window_size: int = 100000
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(
|
|
||||||
cls,
|
|
||||||
*,
|
|
||||||
model: Any,
|
|
||||||
memory: Any,
|
|
||||||
system_prompt: str,
|
|
||||||
tools: list,
|
|
||||||
context_window_size: int = 100000,
|
|
||||||
) -> "Chatbot":
|
|
||||||
"""Factory method to create and configure a Chatbot instance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The chat model to use.
|
|
||||||
memory: The chat memory checkpointer to use.
|
|
||||||
system_prompt: The system prompt to initialize the chatbot.
|
|
||||||
tools: List of LangChain tools the chatbot can use.
|
|
||||||
context_window_size: Maximum tokens for context window.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A configured Chatbot instance.
|
|
||||||
"""
|
|
||||||
instance = cls(
|
|
||||||
model=model,
|
|
||||||
memory=memory,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
context_window_size=context_window_size,
|
|
||||||
)
|
|
||||||
instance.app = instance._build_app(memory=memory, tools=tools)
|
|
||||||
return instance
|
|
||||||
|
|
||||||
def _build_app(
|
|
||||||
self, *, memory: Any, tools: list
|
|
||||||
) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
|
|
||||||
"""Builds the chatbot application workflow using LangGraph.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
memory: The chat memory checkpointer to use.
|
|
||||||
tools: The list of tools the chatbot can use.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A compiled state graph representing the chatbot application.
|
|
||||||
"""
|
|
||||||
llm_with_tools = self.model.bind_tools(tools=tools)
|
|
||||||
|
|
||||||
def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
|
|
||||||
"""Selects a window of messages that fit within the context window size.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msgs: The list of messages to select from.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of messages that fit within the context window size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def approx_counter(items: list[BaseMessage]) -> int:
|
|
||||||
"""Approximate token counter for messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
items: List of messages to count tokens for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Approximate number of tokens in the messages.
|
|
||||||
"""
|
|
||||||
return sum(len(getattr(m, "content", "") or "") for m in items)
|
|
||||||
|
|
||||||
return trim_messages(
|
|
||||||
msgs,
|
|
||||||
strategy="last",
|
|
||||||
token_counter=approx_counter,
|
|
||||||
max_tokens=self.context_window_size,
|
|
||||||
start_on="human",
|
|
||||||
end_on=("human", "tool"),
|
|
||||||
include_system=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def agent_node(state: ChatState) -> dict:
|
|
||||||
"""Agent node for the chatbot workflow.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: The current chat state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The updated chat state after processing.
|
|
||||||
"""
|
|
||||||
# Select the message window to fit in context (trim if needed)
|
|
||||||
window = select_window(state.messages)
|
|
||||||
|
|
||||||
# Ensure the system prompt is present at the start
|
|
||||||
if not window or not isinstance(window[0], SystemMessage):
|
|
||||||
window = [SystemMessage(content=self.system_prompt)] + window
|
|
||||||
|
|
||||||
# Call the LLM with tools
|
|
||||||
response = llm_with_tools.invoke(window)
|
|
||||||
|
|
||||||
# Return the new state
|
|
||||||
return {"messages": [response]}
|
|
||||||
|
|
||||||
def should_continue(state: ChatState) -> str:
|
|
||||||
"""Determines whether to continue the workflow or end it.
|
|
||||||
|
|
||||||
This conditional edge is called after the agent node to decide
|
|
||||||
whether to continue to the tools node (if the last message contains
|
|
||||||
tool calls) or to end the workflow (if no tool calls are present).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: The current chat state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The next node to transition to ("tools" or END).
|
|
||||||
"""
|
|
||||||
# Get the last message
|
|
||||||
last_message = state.messages[-1]
|
|
||||||
|
|
||||||
# Check if the last message contains tool calls
|
|
||||||
# If so, continue to the tools node; otherwise, end the workflow
|
|
||||||
return "tools" if getattr(last_message, "tool_calls", None) else END
|
|
||||||
|
|
||||||
# Compose the workflow
|
|
||||||
workflow = StateGraph(ChatState)
|
|
||||||
workflow.add_node("agent", agent_node)
|
|
||||||
workflow.add_node("tools", ToolNode(tools=tools))
|
|
||||||
workflow.add_edge(START, "agent")
|
|
||||||
workflow.add_conditional_edges("agent", should_continue)
|
|
||||||
workflow.add_edge("tools", "agent")
|
|
||||||
return workflow.compile(checkpointer=memory)
|
|
||||||
|
|
||||||
async def chat(
|
|
||||||
self, *, message: str, chat_id: str = "default"
|
|
||||||
) -> list[BaseMessage]:
|
|
||||||
"""Processes a chat message and returns the chat history.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The user message to process.
|
|
||||||
chat_id: The chat thread ID.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The list of messages in the chat history.
|
|
||||||
"""
|
|
||||||
# Set the right thread ID for memory
|
|
||||||
config = {"configurable": {"thread_id": chat_id}}
|
|
||||||
|
|
||||||
# Single-turn chat (non-streaming)
|
|
||||||
result = await self.app.ainvoke(
|
|
||||||
{"messages": [HumanMessage(content=message)]}, config=config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract and return the messages from the result
|
|
||||||
return result["messages"]
|
|
||||||
|
|
||||||
async def stream_events(
|
|
||||||
self, *, message: str, chat_id: str = "default"
|
|
||||||
) -> AsyncIterator[dict]:
|
|
||||||
"""Stream UI-focused events using astream_events v2.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The user message to process.
|
|
||||||
chat_id: Logical thread identifier; forwarded in the runnable config so
|
|
||||||
memory and tools are scoped per thread.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
dict: One of:
|
|
||||||
- ``{"type": "status", "label": str}`` for short progress updates.
|
|
||||||
- ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
|
|
||||||
where ``chat_history`` only includes ``user``/``assistant`` roles.
|
|
||||||
- ``{"type": "error", "message": str}`` if an exception occurs.
|
|
||||||
"""
|
|
||||||
# Thread-aware config for LangGraph/LangChain
|
|
||||||
config = {"configurable": {"thread_id": chat_id}}
|
|
||||||
|
|
||||||
def _is_root(ev: dict) -> bool:
|
|
||||||
"""Return True if the event is from the root run (v2: empty parent_ids)."""
|
|
||||||
return not ev.get("parent_ids")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for event in self.app.astream_events(
|
|
||||||
{"messages": [HumanMessage(content=message)]},
|
|
||||||
config=config,
|
|
||||||
version="v2",
|
|
||||||
):
|
|
||||||
etype = event.get("event")
|
|
||||||
ename = event.get("name") or ""
|
|
||||||
edata = event.get("data") or {}
|
|
||||||
|
|
||||||
# Stream human-readable progress via the special send_streaming_message tool
|
|
||||||
if etype == "on_tool_start" and ename == "send_streaming_message":
|
|
||||||
tool_in = edata.get("input") or {}
|
|
||||||
msg = tool_in.get("message")
|
|
||||||
if isinstance(msg, str) and msg.strip():
|
|
||||||
yield {"type": "status", "label": msg.strip()}
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Emit the final payload when the root run finishes
|
|
||||||
if etype == "on_chain_end" and _is_root(event):
|
|
||||||
output_obj = edata.get("output")
|
|
||||||
|
|
||||||
# Extract message list from the graph's final output
|
|
||||||
final_msgs = ChatStreamingHelper.extract_messages_from_output(
|
|
||||||
output_obj=output_obj
|
|
||||||
)
|
|
||||||
|
|
||||||
# Normalize for the frontend (only user/assistant with text content)
|
|
||||||
chat_history_payload: list[dict] = []
|
|
||||||
for m in final_msgs:
|
|
||||||
if isinstance(m, BaseMessage):
|
|
||||||
d = ChatStreamingHelper.message_to_dict(msg=m)
|
|
||||||
elif isinstance(m, dict):
|
|
||||||
d = ChatStreamingHelper.dict_message_to_dict(obj=m)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
if d.get("role") in ("user", "assistant") and d.get("content"):
|
|
||||||
chat_history_payload.append(d)
|
|
||||||
|
|
||||||
yield {
|
|
||||||
"type": "final",
|
|
||||||
"response": {
|
|
||||||
"thread": chat_id,
|
|
||||||
"chat_history": chat_history_payload,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
# Emit a single error envelope and end the stream
|
|
||||||
logger.error(f"Error in stream_events: {str(exc)}", exc_info=True)
|
|
||||||
yield {"type": "error", "message": f"Error processing request: {exc}"}
|
|
||||||
|
|
@ -1,239 +0,0 @@
|
||||||
"""Streaming helper utilities for chat message processing and normalization."""
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Literal, Mapping, Optional
|
|
||||||
|
|
||||||
from langchain_core.messages import (
|
|
||||||
AIMessage,
|
|
||||||
BaseMessage,
|
|
||||||
HumanMessage,
|
|
||||||
SystemMessage,
|
|
||||||
ToolMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
Role = Literal["user", "assistant", "system", "tool"]
|
|
||||||
|
|
||||||
|
|
||||||
class ChatStreamingHelper:
|
|
||||||
"""Pure helper methods for streaming and message normalization.
|
|
||||||
|
|
||||||
This class provides static utility methods for converting between different
|
|
||||||
message formats, extracting content, and normalizing message structures
|
|
||||||
for streaming chat applications.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def role_from_message(*, msg: BaseMessage) -> Role:
|
|
||||||
"""Extract the role from a BaseMessage instance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msg: The BaseMessage instance to extract the role from.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The role as a string literal: "user", "assistant", "system", or "tool".
|
|
||||||
Defaults to "assistant" if the message type is not recognized.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> from langchain_core.messages import HumanMessage
|
|
||||||
>>> msg = HumanMessage(content="Hello")
|
|
||||||
>>> ChatStreamingHelper.role_from_message(msg=msg)
|
|
||||||
'user'
|
|
||||||
"""
|
|
||||||
if isinstance(msg, HumanMessage):
|
|
||||||
return "user"
|
|
||||||
if isinstance(msg, AIMessage):
|
|
||||||
return "assistant"
|
|
||||||
if isinstance(msg, SystemMessage):
|
|
||||||
return "system"
|
|
||||||
if isinstance(msg, ToolMessage):
|
|
||||||
return "tool"
|
|
||||||
return getattr(msg, "role", "assistant")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def flatten_content(*, content: Any) -> str:
|
|
||||||
"""Convert complex content structures to plain text.
|
|
||||||
|
|
||||||
This method handles various content formats including strings, lists of
|
|
||||||
content parts, and dictionaries with text fields. It's designed to
|
|
||||||
normalize content from different message sources into a consistent
|
|
||||||
plain text format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: The content to flatten. Can be:
|
|
||||||
- str: Returned as-is after stripping whitespace
|
|
||||||
- list: Each item processed and joined with newlines
|
|
||||||
- dict: Text extracted from "text" or "content" fields
|
|
||||||
- None: Returns empty string
|
|
||||||
- Any other type: Converted to string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The flattened content as a plain text string with whitespace stripped.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
|
|
||||||
>>> ChatStreamingHelper.flatten_content(content=content)
|
|
||||||
'Hello
|
|
||||||
nworld'
|
|
||||||
|
|
||||||
>>> content = {"text": "Simple message"}
|
|
||||||
>>> ChatStreamingHelper.flatten_content(content=content)
|
|
||||||
'Simple message'
|
|
||||||
"""
|
|
||||||
if content is None:
|
|
||||||
return ""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content.strip()
|
|
||||||
if isinstance(content, list):
|
|
||||||
parts: List[str] = []
|
|
||||||
for part in content:
|
|
||||||
if isinstance(part, dict):
|
|
||||||
if "text" in part and isinstance(part["text"], str):
|
|
||||||
parts.append(part["text"])
|
|
||||||
elif part.get("type") == "text" and isinstance(
|
|
||||||
part.get("text"), str
|
|
||||||
):
|
|
||||||
parts.append(part["text"])
|
|
||||||
elif "content" in part and isinstance(part["content"], str):
|
|
||||||
parts.append(part["content"])
|
|
||||||
else:
|
|
||||||
# Fallback for unknown dictionary structures
|
|
||||||
val = part.get("value")
|
|
||||||
if isinstance(val, str):
|
|
||||||
parts.append(val)
|
|
||||||
else:
|
|
||||||
parts.append(str(part))
|
|
||||||
return "\n".join(p.strip() for p in parts if p is not None)
|
|
||||||
if isinstance(content, dict):
|
|
||||||
if "text" in content and isinstance(content["text"], str):
|
|
||||||
return content["text"].strip()
|
|
||||||
if "content" in content and isinstance(content["content"], str):
|
|
||||||
return content["content"].strip()
|
|
||||||
return str(content).strip()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
|
|
||||||
"""Convert a BaseMessage instance to a dictionary for streaming output.
|
|
||||||
|
|
||||||
This method normalizes BaseMessage instances into a consistent dictionary
|
|
||||||
format suitable for JSON serialization and streaming to clients.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msg: The BaseMessage instance to convert.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary containing:
|
|
||||||
- "role": The message role (user, assistant, system, tool)
|
|
||||||
- "content": The flattened message content as plain text
|
|
||||||
- "tool_calls": Tool calls if present (optional)
|
|
||||||
- "name": Message name if present (optional)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> from langchain_core.messages import HumanMessage
|
|
||||||
>>> msg = HumanMessage(content="Hello there")
|
|
||||||
>>> result = ChatStreamingHelper.message_to_dict(msg=msg)
|
|
||||||
>>> result["role"]
|
|
||||||
'user'
|
|
||||||
>>> result["content"]
|
|
||||||
'Hello there'
|
|
||||||
"""
|
|
||||||
payload: Dict[str, Any] = {
|
|
||||||
"role": ChatStreamingHelper.role_from_message(msg=msg),
|
|
||||||
"content": ChatStreamingHelper.flatten_content(
|
|
||||||
content=getattr(msg, "content", "")
|
|
||||||
),
|
|
||||||
}
|
|
||||||
tool_calls = getattr(msg, "tool_calls", None)
|
|
||||||
if tool_calls:
|
|
||||||
payload["tool_calls"] = tool_calls
|
|
||||||
name = getattr(msg, "name", None)
|
|
||||||
if name:
|
|
||||||
payload["name"] = name
|
|
||||||
return payload
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Convert a dictionary-shaped message to a normalized dictionary.
|
|
||||||
|
|
||||||
This method handles messages that come from serialized state and are
|
|
||||||
represented as dictionaries rather than BaseMessage instances. It
|
|
||||||
normalizes various dictionary formats into a consistent structure.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obj: The dictionary-shaped message to convert. Expected to contain
|
|
||||||
fields like "role", "type", "content", "text", etc.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A normalized dictionary containing:
|
|
||||||
- "role": The message role (user, assistant, system, tool)
|
|
||||||
- "content": The flattened message content as plain text
|
|
||||||
- "tool_calls": Tool calls if present (optional)
|
|
||||||
- "name": Message name if present (optional)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> obj = {"type": "human", "content": "Hello"}
|
|
||||||
>>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
|
|
||||||
>>> result["role"]
|
|
||||||
'user'
|
|
||||||
>>> result["content"]
|
|
||||||
'Hello'
|
|
||||||
"""
|
|
||||||
role: Optional[str] = obj.get("role")
|
|
||||||
if not role:
|
|
||||||
# Handle alternative type field mappings
|
|
||||||
typ = obj.get("type")
|
|
||||||
if typ in ("human", "user"):
|
|
||||||
role = "user"
|
|
||||||
elif typ in ("ai", "assistant"):
|
|
||||||
role = "assistant"
|
|
||||||
elif typ in ("system",):
|
|
||||||
role = "system"
|
|
||||||
elif typ in ("tool", "function"):
|
|
||||||
role = "tool"
|
|
||||||
|
|
||||||
content = obj.get("content")
|
|
||||||
if content is None and "text" in obj:
|
|
||||||
content = obj["text"]
|
|
||||||
|
|
||||||
out: Dict[str, Any] = {
|
|
||||||
"role": role or "assistant",
|
|
||||||
"content": ChatStreamingHelper.flatten_content(content=content),
|
|
||||||
}
|
|
||||||
if "tool_calls" in obj:
|
|
||||||
out["tool_calls"] = obj["tool_calls"]
|
|
||||||
if obj.get("name"):
|
|
||||||
out["name"] = obj["name"]
|
|
||||||
return out
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
|
|
||||||
"""Extract messages from LangGraph output objects.
|
|
||||||
|
|
||||||
This method handles various output formats from LangGraph execution,
|
|
||||||
extracting the messages list from different possible structures.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_obj: The output object from LangGraph execution. Can be:
|
|
||||||
- An object with a "messages" attribute
|
|
||||||
- A dictionary with a "messages" key
|
|
||||||
- Any other object (returns empty list)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of extracted messages, or an empty list if no messages
|
|
||||||
are found or if the output object is None.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> output = {"messages": [{"role": "user", "content": "Hello"}]}
|
|
||||||
>>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
|
|
||||||
>>> len(messages)
|
|
||||||
1
|
|
||||||
"""
|
|
||||||
if output_obj is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Try to parse dicts first
|
|
||||||
if isinstance(output_obj, dict):
|
|
||||||
msgs = output_obj.get("messages")
|
|
||||||
return msgs if isinstance(msgs, list) else []
|
|
||||||
|
|
||||||
# Then try to get messages attribute
|
|
||||||
msgs = getattr(output_obj, "messages", None)
|
|
||||||
return msgs if isinstance(msgs, list) else []
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,197 +0,0 @@
|
||||||
from typing import AsyncIterator
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from fastapi import Request
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
||||||
from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# Tools Table
|
|
||||||
class Tool(Base):
|
|
||||||
"""Available chatbot tools.
|
|
||||||
|
|
||||||
Stores information about all available tools that can be assigned to users.
|
|
||||||
Each tool has a unique tool_id that corresponds to the registry tool_id.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__tablename__ = "tools"
|
|
||||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
|
||||||
tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
label: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
category: Mapped[str] = mapped_column(String(50), nullable=False)
|
|
||||||
description: Mapped[str] = mapped_column(String(1000), nullable=False)
|
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
|
||||||
date_created: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
default=lambda: datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
date_updated: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
default=lambda: datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# User-Tool Mapping Table
|
|
||||||
class UserToolMapping(Base):
|
|
||||||
"""Mapping of users to their available tools.
|
|
||||||
|
|
||||||
Many-to-many relationship between users and tools.
|
|
||||||
- One user can have multiple tools
|
|
||||||
- One tool can be assigned to multiple users
|
|
||||||
|
|
||||||
The combination of user_id and tool_id is unique.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__tablename__ = "user_tools"
|
|
||||||
__table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),)
|
|
||||||
|
|
||||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
|
||||||
user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False)
|
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
|
||||||
date_granted: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
default=lambda: datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
date_updated: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
default=lambda: datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# User Thread Mapping Table
|
|
||||||
class UserThreadMapping(Base):
|
|
||||||
"""Mapping of users to their chat threads.
|
|
||||||
|
|
||||||
Used to keep track of which user owns which chat thread.
|
|
||||||
Also stores meta data like thread name.
|
|
||||||
|
|
||||||
1:N relationship between user and thread. A thread belongs to exactly one user.
|
|
||||||
A user can have multiple threads.
|
|
||||||
Thread_id is unique in the table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__tablename__ = "user_threads"
|
|
||||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
|
||||||
user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
thread_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
|
||||||
thread_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
date_created: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
default=lambda: datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
date_updated: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
default=lambda: datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Dependency that pulls the sessionmaker off app.state
|
|
||||||
# This is set in app.py on startup in @asynccontextmanager
|
|
||||||
# TODO: If we use SQLAlchemy in other places, we can move this to a shared module
|
|
||||||
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
|
|
||||||
SessionLocal: async_sessionmaker[AsyncSession] = (
|
|
||||||
request.app.state.checkpoint_sessionmaker
|
|
||||||
)
|
|
||||||
async with SessionLocal() as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
# Optional helper to init tables at startup (demo only)
|
|
||||||
async def init_models(engine) -> None:
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
|
|
||||||
async def sync_tools_from_registry(session: AsyncSession) -> None:
|
|
||||||
"""Sync tools from tool registry to database.
|
|
||||||
|
|
||||||
This function:
|
|
||||||
- Adds new tools from the registry to the database
|
|
||||||
- Updates existing tools with current registry information
|
|
||||||
- Marks tools not present in the registry as inactive
|
|
||||||
|
|
||||||
Should be called on application startup after database initialization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Active database session
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.info("Syncing tools from registry to database...")
|
|
||||||
|
|
||||||
# Get all tools from the registry
|
|
||||||
registry = get_registry()
|
|
||||||
registry_tools = registry.get_all_tools()
|
|
||||||
|
|
||||||
# Create a set of tool_ids from the registry
|
|
||||||
registry_tool_ids = {tool.tool_id for tool in registry_tools}
|
|
||||||
|
|
||||||
logger.info(f"Found {len(registry_tools)} tools in registry")
|
|
||||||
|
|
||||||
# Get all existing tools from the database
|
|
||||||
result = await session.execute(select(Tool))
|
|
||||||
db_tools = result.scalars().all()
|
|
||||||
db_tools_by_tool_id = {tool.tool_id: tool for tool in db_tools}
|
|
||||||
|
|
||||||
logger.info(f"Found {len(db_tools)} tools in database")
|
|
||||||
|
|
||||||
# Track changes
|
|
||||||
added_count = 0
|
|
||||||
updated_count = 0
|
|
||||||
deactivated_count = 0
|
|
||||||
|
|
||||||
# Sync tools from registry to database
|
|
||||||
for registry_tool in registry_tools:
|
|
||||||
if registry_tool.tool_id in db_tools_by_tool_id:
|
|
||||||
# Tool exists - update it
|
|
||||||
# Preserve label and description (user-editable fields)
|
|
||||||
db_tool = db_tools_by_tool_id[registry_tool.tool_id]
|
|
||||||
db_tool.name = registry_tool.name
|
|
||||||
db_tool.category = registry_tool.category
|
|
||||||
db_tool.is_active = True
|
|
||||||
db_tool.date_updated = datetime.now(timezone.utc)
|
|
||||||
updated_count += 1
|
|
||||||
logger.debug(f"Updated tool: {registry_tool.tool_id}")
|
|
||||||
else:
|
|
||||||
# Tool doesn't exist - create it
|
|
||||||
new_tool = Tool(
|
|
||||||
tool_id=registry_tool.tool_id,
|
|
||||||
name=registry_tool.name,
|
|
||||||
label=registry_tool.tool_id, # Use tool_id as label per spec
|
|
||||||
category=registry_tool.category,
|
|
||||||
description=registry_tool.description or "",
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
session.add(new_tool)
|
|
||||||
added_count += 1
|
|
||||||
logger.debug(f"Added new tool: {registry_tool.tool_id}")
|
|
||||||
|
|
||||||
# Mark tools not in registry as inactive
|
|
||||||
for db_tool in db_tools:
|
|
||||||
if db_tool.tool_id not in registry_tool_ids and db_tool.is_active:
|
|
||||||
db_tool.is_active = False
|
|
||||||
db_tool.date_updated = datetime.now(timezone.utc)
|
|
||||||
deactivated_count += 1
|
|
||||||
logger.debug(f"Deactivated tool not in registry: {db_tool.tool_id}")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Tool sync complete: {added_count} added, {updated_count} updated, {deactivated_count} deactivated"
|
|
||||||
)
|
|
||||||
|
|
@ -1,106 +0,0 @@
|
||||||
"""PostgreSQL checkpointer utilities for LangGraph memory."""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
# Fix for Windows asyncio compatibility with psycopg (backup in case app.py fix didn't apply)
|
|
||||||
if sys.platform == 'win32':
|
|
||||||
try:
|
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
||||||
except RuntimeError:
|
|
||||||
pass # Already set
|
|
||||||
|
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
||||||
from psycopg_pool import AsyncConnectionPool
|
|
||||||
from psycopg.rows import dict_row
|
|
||||||
from modules.shared.configuration import APP_CONFIG
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Global checkpointer instance
|
|
||||||
_checkpointer_instance: Optional[AsyncPostgresSaver] = None
|
|
||||||
_connection_pool: Optional[AsyncConnectionPool] = None
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize_checkpointer() -> None:
|
|
||||||
"""Initialize the PostgreSQL checkpointer for LangGraph.
|
|
||||||
|
|
||||||
This should be called during application startup.
|
|
||||||
Creates a connection pool and PostgresSaver instance.
|
|
||||||
"""
|
|
||||||
global _checkpointer_instance, _connection_pool
|
|
||||||
|
|
||||||
if _checkpointer_instance is not None:
|
|
||||||
logger.info("Checkpointer already initialized")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get database configuration from environment
|
|
||||||
host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost")
|
|
||||||
database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat")
|
|
||||||
user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev")
|
|
||||||
password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET")
|
|
||||||
port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432")
|
|
||||||
|
|
||||||
# Build connection string
|
|
||||||
connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}"
|
|
||||||
|
|
||||||
# Create async connection pool
|
|
||||||
_connection_pool = AsyncConnectionPool(
|
|
||||||
conninfo=connection_string,
|
|
||||||
min_size=2,
|
|
||||||
max_size=10,
|
|
||||||
kwargs={"autocommit": True, "row_factory": dict_row},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the connection pool
|
|
||||||
await _connection_pool.open()
|
|
||||||
|
|
||||||
# Create AsyncPostgresSaver with the pool
|
|
||||||
_checkpointer_instance = AsyncPostgresSaver(_connection_pool)
|
|
||||||
|
|
||||||
# Setup the checkpointer (creates tables if needed)
|
|
||||||
await _checkpointer_instance.setup()
|
|
||||||
|
|
||||||
logger.info("PostgreSQL checkpointer initialized successfully")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
async def close_checkpointer() -> None:
|
|
||||||
"""Close the checkpointer and connection pool.
|
|
||||||
|
|
||||||
This should be called during application shutdown.
|
|
||||||
"""
|
|
||||||
global _checkpointer_instance, _connection_pool
|
|
||||||
|
|
||||||
if _connection_pool is not None:
|
|
||||||
try:
|
|
||||||
await _connection_pool.close()
|
|
||||||
logger.info("PostgreSQL checkpointer connection pool closed")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error closing checkpointer connection pool: {str(e)}")
|
|
||||||
|
|
||||||
_checkpointer_instance = None
|
|
||||||
_connection_pool = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_checkpointer() -> AsyncPostgresSaver:
|
|
||||||
"""Get the global PostgreSQL checkpointer instance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The initialized AsyncPostgresSaver instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If checkpointer is not initialized
|
|
||||||
"""
|
|
||||||
if _checkpointer_instance is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"PostgreSQL checkpointer not initialized. "
|
|
||||||
"Call initialize_checkpointer() during application startup."
|
|
||||||
)
|
|
||||||
return _checkpointer_instance
|
|
||||||
|
|
@ -1,39 +0,0 @@
|
||||||
"""Mock permissions module for chatbot access control.
|
|
||||||
|
|
||||||
This module provides mock permission functions that will be replaced
|
|
||||||
with actual database-driven permissions in the future.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Replace these mock implementations with actual database queries
|
|
||||||
|
|
||||||
|
|
||||||
def get_chatbot_tools(*, user_id: str) -> list[str]:
|
|
||||||
"""Get list of tool IDs that the chatbot can use for a given user."""
|
|
||||||
registry = get_registry()
|
|
||||||
return registry.list_tool_ids()
|
|
||||||
|
|
||||||
|
|
||||||
def get_chatbot_model(*, user_id: str) -> str:
|
|
||||||
"""Gets the chatbot model(s) a user is allowed to use."""
|
|
||||||
return "claude_4_5"
|
|
||||||
|
|
||||||
|
|
||||||
def get_system_prompt(*, user_id: str) -> str:
|
|
||||||
"""Get the system prompt for a user's chatbot session.
|
|
||||||
|
|
||||||
This is a mock implementation that returns a generic prompt with today's date.
|
|
||||||
In production, this will query the database for user-specific or role-specific prompts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The unique identifier of the user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The system prompt string with the current date
|
|
||||||
"""
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
return f"You're a smart assistant. Today is {current_date}"
|
|
||||||
|
|
@ -1,305 +0,0 @@
|
||||||
"""Tool registry for auto-discovering and managing chatbot tools.
|
|
||||||
|
|
||||||
This module provides a central registry that automatically discovers all tools
|
|
||||||
in the chatbotTools directory structure and provides methods to query them.
|
|
||||||
The registry is built in-memory at startup and does not require a database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from langchain_core.tools import BaseTool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolMetadata:
|
|
||||||
"""Metadata about a discovered chatbot tool.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
tool_id: Unique identifier (e.g., 'shared.tavily_search')
|
|
||||||
name: Function name of the tool
|
|
||||||
category: Category of the tool ('shared' or 'customer')
|
|
||||||
description: Tool description from docstring
|
|
||||||
tool_instance: The actual LangChain tool instance
|
|
||||||
module_path: Full Python module path
|
|
||||||
"""
|
|
||||||
|
|
||||||
tool_id: str
|
|
||||||
name: str
|
|
||||||
category: str
|
|
||||||
description: str
|
|
||||||
tool_instance: BaseTool
|
|
||||||
module_path: str
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
"""Return a pretty-printed string representation for logging."""
|
|
||||||
return (
|
|
||||||
f"ToolMetadata(\n"
|
|
||||||
f" tool_id='{self.tool_id}',\n"
|
|
||||||
f" name='{self.name}',\n"
|
|
||||||
f" category='{self.category}',\n"
|
|
||||||
f" description='{self.description}',\n"
|
|
||||||
f" module_path='{self.module_path}'\n"
|
|
||||||
f")"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolRegistry:
|
|
||||||
"""Central registry for all chatbot tools.
|
|
||||||
|
|
||||||
This class discovers and catalogs all tools decorated with @tool in the
|
|
||||||
chatbotTools directory structure. Tools are automatically discovered at
|
|
||||||
initialization by scanning the filesystem.
|
|
||||||
|
|
||||||
The registry provides methods to query tools by ID, category, or get all tools.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initialize an empty tool registry."""
|
|
||||||
self._tools: Dict[str, ToolMetadata] = {}
|
|
||||||
self._initialized: bool = False
|
|
||||||
|
|
||||||
def initialize(self) -> None:
|
|
||||||
"""Discover and register all tools from the chatbotTools directory.
|
|
||||||
|
|
||||||
This method scans both sharedTools and customerTools directories,
|
|
||||||
imports all tool*.py modules, and extracts functions decorated with @tool.
|
|
||||||
|
|
||||||
This method is idempotent - calling it multiple times has no effect
|
|
||||||
after the first initialization.
|
|
||||||
"""
|
|
||||||
if self._initialized:
|
|
||||||
logger.debug("Tool registry already initialized, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Initializing tool registry...")
|
|
||||||
|
|
||||||
# Get base path to chatbotTools directory
|
|
||||||
base_path = Path(__file__).parent.parent / "chatbotTools"
|
|
||||||
|
|
||||||
if not base_path.exists():
|
|
||||||
logger.warning(f"chatbotTools directory not found at {base_path}")
|
|
||||||
self._initialized = True
|
|
||||||
return
|
|
||||||
|
|
||||||
# Discover tools in each category
|
|
||||||
self._discover_category(
|
|
||||||
category_path=base_path / "sharedTools", category="shared"
|
|
||||||
)
|
|
||||||
self._discover_category(
|
|
||||||
category_path=base_path / "customerTools", category="customer"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._initialized = True
|
|
||||||
logger.info(f"Tool registry initialized with {len(self._tools)} tools")
|
|
||||||
|
|
||||||
def _discover_category(self, *, category_path: Path, category: str) -> None:
|
|
||||||
"""Discover all tools in a specific category directory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
category_path: Path to the category directory (sharedTools or customerTools)
|
|
||||||
category: Category name ('shared' or 'customer')
|
|
||||||
"""
|
|
||||||
if not category_path.exists():
|
|
||||||
logger.warning(f"Category directory not found: {category_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug(f"Discovering tools in category: {category}")
|
|
||||||
|
|
||||||
# Find all tool*.py files (excluding __init__.py)
|
|
||||||
tool_files = [
|
|
||||||
f for f in category_path.glob("tool*.py") if f.name != "__init__.py"
|
|
||||||
]
|
|
||||||
|
|
||||||
for tool_file in tool_files:
|
|
||||||
self._import_and_register_tools(
|
|
||||||
tool_file=tool_file, category=category, category_path=category_path
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Discovered {len(tool_files)} tool files in {category}")
|
|
||||||
|
|
||||||
def _import_and_register_tools(
|
|
||||||
self, *, tool_file: Path, category: str, category_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""Import a tool module and register all discovered tools.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_file: Path to the tool Python file
|
|
||||||
category: Category name ('shared' or 'customer')
|
|
||||||
category_path: Path to the category directory
|
|
||||||
"""
|
|
||||||
# Construct module name
|
|
||||||
module_name = (
|
|
||||||
f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Import the module
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
|
|
||||||
# Find all BaseTool instances in the module
|
|
||||||
tools_found = 0
|
|
||||||
for name, obj in inspect.getmembers(module):
|
|
||||||
if isinstance(obj, BaseTool):
|
|
||||||
self._register_tool(
|
|
||||||
tool_instance=obj,
|
|
||||||
name=name,
|
|
||||||
category=category,
|
|
||||||
module_path=module_name,
|
|
||||||
)
|
|
||||||
tools_found += 1
|
|
||||||
|
|
||||||
if tools_found == 0:
|
|
||||||
logger.warning(f"No tools found in {module_name}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"Loaded {tools_found} tool(s) from {module_name}")
|
|
||||||
|
|
||||||
except ImportError as e:
|
|
||||||
logger.error(
|
|
||||||
f"Import error loading tools from {module_name}: {str(e)}. "
|
|
||||||
f"This tool will not be available."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _register_tool(
|
|
||||||
self, *, tool_instance: BaseTool, name: str, category: str, module_path: str
|
|
||||||
) -> None:
|
|
||||||
"""Register a single tool in the registry.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_instance: The LangChain tool instance
|
|
||||||
name: Function name of the tool
|
|
||||||
category: Category name ('shared' or 'customer')
|
|
||||||
module_path: Full Python module path
|
|
||||||
"""
|
|
||||||
tool_id = f"{category}.{name}"
|
|
||||||
|
|
||||||
# Check for duplicate tool IDs
|
|
||||||
if tool_id in self._tools:
|
|
||||||
logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting")
|
|
||||||
|
|
||||||
metadata = ToolMetadata(
|
|
||||||
tool_id=tool_id,
|
|
||||||
name=name,
|
|
||||||
category=category,
|
|
||||||
description=tool_instance.description or "",
|
|
||||||
tool_instance=tool_instance,
|
|
||||||
module_path=module_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._tools[tool_id] = metadata
|
|
||||||
logger.debug(f"Registered tool: {tool_id}")
|
|
||||||
|
|
||||||
def get_all_tools(self) -> List[ToolMetadata]:
|
|
||||||
"""Get all registered tools.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of all tool metadata objects
|
|
||||||
"""
|
|
||||||
return list(self._tools.values())
|
|
||||||
|
|
||||||
def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]:
|
|
||||||
"""Get a specific tool by its ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_id: The tool identifier (e.g., 'shared.tavily_search')
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tool metadata if found, None otherwise
|
|
||||||
"""
|
|
||||||
return self._tools.get(tool_id)
|
|
||||||
|
|
||||||
def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]:
|
|
||||||
"""Get all tools in a specific category.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
category: Category name ('shared' or 'customer')
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of tool metadata for the specified category
|
|
||||||
"""
|
|
||||||
return [t for t in self._tools.values() if t.category == category]
|
|
||||||
|
|
||||||
def list_tool_ids(self) -> List[str]:
|
|
||||||
"""Get a list of all registered tool IDs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of tool ID strings
|
|
||||||
"""
|
|
||||||
return list(self._tools.keys())
|
|
||||||
|
|
||||||
def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]:
|
|
||||||
"""Get actual tool instances for a list of tool IDs.
|
|
||||||
|
|
||||||
This is useful for filtering tools based on user permissions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_ids: List of tool IDs to retrieve
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of BaseTool instances for the specified IDs
|
|
||||||
"""
|
|
||||||
instances = []
|
|
||||||
for tool_id in tool_ids:
|
|
||||||
metadata = self.get_tool(tool_id=tool_id)
|
|
||||||
if metadata:
|
|
||||||
instances.append(metadata.tool_instance)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Tool ID not found in registry: {tool_id}")
|
|
||||||
return instances
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_initialized(self) -> bool:
|
|
||||||
"""Check if the registry has been initialized.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if initialized, False otherwise
|
|
||||||
"""
|
|
||||||
return self._initialized
|
|
||||||
|
|
||||||
|
|
||||||
# Global registry instance
|
|
||||||
_registry: Optional[ToolRegistry] = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_registry() -> ToolRegistry:
|
|
||||||
"""Get the global tool registry instance.
|
|
||||||
|
|
||||||
This function ensures the registry is initialized on first access.
|
|
||||||
Subsequent calls return the same instance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The global ToolRegistry instance
|
|
||||||
"""
|
|
||||||
global _registry
|
|
||||||
|
|
||||||
if _registry is None:
|
|
||||||
_registry = ToolRegistry()
|
|
||||||
|
|
||||||
if not _registry.is_initialized:
|
|
||||||
_registry.initialize()
|
|
||||||
|
|
||||||
return _registry
|
|
||||||
|
|
||||||
|
|
||||||
def reinitialize_registry() -> ToolRegistry:
|
|
||||||
"""Force reinitialize the tool registry.
|
|
||||||
|
|
||||||
This is useful for testing or when tools are added dynamically.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The reinitialized ToolRegistry instance
|
|
||||||
"""
|
|
||||||
global _registry
|
|
||||||
_registry = ToolRegistry()
|
|
||||||
_registry.initialize()
|
|
||||||
return _registry
|
|
||||||
|
|
@ -13,14 +13,15 @@ async def start() -> None:
|
||||||
from modules.features.syncDelta import mainSyncDelta
|
from modules.features.syncDelta import mainSyncDelta
|
||||||
mainSyncDelta.startSyncManager(eventUser)
|
mainSyncDelta.startSyncManager(eventUser)
|
||||||
|
|
||||||
# Feature ChatBot
|
# Feature ...
|
||||||
from modules.features.chatBot.mainChatBot import start as startChatBot
|
|
||||||
await startChatBot()
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def stop() -> None:
|
async def stop() -> None:
|
||||||
""" Stop feature triggers and background managers """
|
""" Stop feature triggers and background managers """
|
||||||
|
|
||||||
# Feature ChatBot
|
# Feature ...
|
||||||
from modules.features.chatBot.mainChatBot import stop as stopChatBot
|
|
||||||
await stopChatBot()
|
return True
|
||||||
|
|
|
||||||
|
|
@ -1,653 +0,0 @@
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from fastapi.requests import Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from modules.datamodels.datamodelUam import User, UserPrivilege
|
|
||||||
from modules.security.auth import getCurrentUser, limiter
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from modules.features.chatBot.subChatbotDatabase import get_async_db_session
|
|
||||||
from modules.features.chatBot.mainChatBot import (
|
|
||||||
get_or_create_thread_for_user,
|
|
||||||
)
|
|
||||||
from modules.datamodels.datamodelChatbot import (
|
|
||||||
ChatMessageRequest,
|
|
||||||
MessageItem,
|
|
||||||
ChatMessageResponse,
|
|
||||||
ThreadSummary,
|
|
||||||
ThreadListResponse,
|
|
||||||
ThreadDetail,
|
|
||||||
RenameThreadRequest,
|
|
||||||
DeleteResponse,
|
|
||||||
ToolListResponse,
|
|
||||||
ToolInfo,
|
|
||||||
GrantToolRequest,
|
|
||||||
GrantToolResponse,
|
|
||||||
RevokeToolRequest,
|
|
||||||
RevokeToolResponse,
|
|
||||||
UpdateToolRequest,
|
|
||||||
UpdateToolResponse,
|
|
||||||
)
|
|
||||||
from modules.features.chatBot import mainChatBot as chat_service
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(
|
|
||||||
prefix="/api/chatbot",
|
|
||||||
tags=["Chatbot"],
|
|
||||||
responses={404: {"description": "Not found"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# --- Actual endpoints for chatbot ---
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/message/stream")
|
|
||||||
@limiter.limit("30/minute")
|
|
||||||
async def post_chat_message_stream(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
message_request: ChatMessageRequest,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> StreamingResponse:
|
|
||||||
"""
|
|
||||||
Post a message to a chat thread with streaming progress updates.
|
|
||||||
Creates a new thread if thread_id is not provided.
|
|
||||||
|
|
||||||
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Validate and get tools for the request
|
|
||||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
|
||||||
user_id=currentUser.id,
|
|
||||||
requested_tool_ids=message_request.tools,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get or create thread using helper function
|
|
||||||
thread_id = await get_or_create_thread_for_user(
|
|
||||||
thread_id=message_request.thread_id,
|
|
||||||
user=currentUser,
|
|
||||||
session=session,
|
|
||||||
thread_name=message_request.message[:100],
|
|
||||||
refresh_date_updated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"User {currentUser.id} posted streaming message to thread {thread_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
chat_service.post_message_stream(
|
|
||||||
thread_id=thread_id,
|
|
||||||
message=message_request.message,
|
|
||||||
user=currentUser,
|
|
||||||
tool_ids=tool_ids,
|
|
||||||
),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
except PermissionError as e:
|
|
||||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=str(e) or "Permission denied",
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=str(e) or "Permission denied",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/message", response_model=ChatMessageResponse)
|
|
||||||
@limiter.limit("30/minute")
|
|
||||||
async def post_chat_message(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
message_request: ChatMessageRequest,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> ChatMessageResponse:
|
|
||||||
"""
|
|
||||||
Post a message to a chat thread and get assistant response (non-streaming).
|
|
||||||
Creates a new thread if thread_id is not provided.
|
|
||||||
|
|
||||||
For streaming updates, use the /message/stream endpoint instead.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Validate and get tools for the request
|
|
||||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
|
||||||
user_id=currentUser.id,
|
|
||||||
requested_tool_ids=message_request.tools,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get or create thread using helper function
|
|
||||||
thread_id = await get_or_create_thread_for_user(
|
|
||||||
thread_id=message_request.thread_id,
|
|
||||||
user=currentUser,
|
|
||||||
session=session,
|
|
||||||
thread_name=message_request.message[:100],
|
|
||||||
refresh_date_updated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
|
||||||
|
|
||||||
response = await chat_service.post_message(
|
|
||||||
thread_id=thread_id,
|
|
||||||
message=message_request.message,
|
|
||||||
user=currentUser,
|
|
||||||
tool_ids=tool_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
except PermissionError as e:
|
|
||||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=str(e) or "Permission denied",
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=str(e) or "Permission denied",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/threads", response_model=ThreadListResponse)
|
|
||||||
@limiter.limit("30/minute")
|
|
||||||
async def get_all_threads(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> ThreadListResponse:
|
|
||||||
"""
|
|
||||||
Get all chat threads for the current user.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get all threads for the current user
|
|
||||||
threads = await chat_service.get_all_threads_for_user(
|
|
||||||
user=currentUser, session=session
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"User {currentUser.id} retrieved {len(threads)} threads")
|
|
||||||
|
|
||||||
return ThreadListResponse(threads=threads)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error retrieving threads: {type(e).__name__}: {str(e)}", exc_info=True
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to retrieve threads: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/threads/{thread_id}", response_model=ThreadDetail)
|
|
||||||
@limiter.limit("30/minute")
|
|
||||||
async def get_thread_by_id(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
thread_id: str,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> ThreadDetail:
|
|
||||||
"""
|
|
||||||
Get a specific chat thread with all its messages from LangGraph checkpointer.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
thread_detail = await chat_service.get_thread_detail_for_user(
|
|
||||||
thread_id=thread_id,
|
|
||||||
user=currentUser,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
|
|
||||||
return thread_detail
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Thread not found: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=str(e) or "Thread not found",
|
|
||||||
)
|
|
||||||
except PermissionError as e:
|
|
||||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=str(e) or "Permission denied",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to retrieve thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/threads/{thread_id}", response_model=DeleteResponse)
|
|
||||||
@limiter.limit("30/minute")
|
|
||||||
async def rename_thread(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
thread_id: str,
|
|
||||||
rename_request: RenameThreadRequest,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> DeleteResponse:
|
|
||||||
"""
|
|
||||||
Rename a chat thread.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await chat_service.update_thread_name(
|
|
||||||
thread_id=thread_id,
|
|
||||||
user=currentUser,
|
|
||||||
new_thread_name=rename_request.new_name,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"User {currentUser.id} renamed thread {thread_id} to '{rename_request.new_name}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
return DeleteResponse(
|
|
||||||
message=f"Thread {thread_id} successfully renamed",
|
|
||||||
thread_id=thread_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=str(e) or "Thread not found or permission denied",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error renaming thread {thread_id}: {type(e).__name__}: {str(e)}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to rename thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/threads/{thread_id}", response_model=DeleteResponse)
|
|
||||||
@limiter.limit("10/minute")
|
|
||||||
async def delete_thread(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
thread_id: str,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> DeleteResponse:
|
|
||||||
"""
|
|
||||||
Delete a chat thread and all its associated data from both LangGraph and database.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await chat_service.delete_thread_for_user(
|
|
||||||
thread_id=thread_id,
|
|
||||||
user=currentUser,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"User {currentUser.id} deleted thread {thread_id}")
|
|
||||||
|
|
||||||
return DeleteResponse(
|
|
||||||
message=f"Thread {thread_id} successfully deleted",
|
|
||||||
thread_id=thread_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=str(e) or "Thread not found or permission denied",
|
|
||||||
)
|
|
||||||
except PermissionError as e:
|
|
||||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=str(e) or "Permission denied",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error deleting thread {thread_id}: {type(e).__name__}: {str(e)}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to delete thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Tool Management Endpoints
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tools", response_model=ToolListResponse)
|
|
||||||
@limiter.limit("30/minute")
|
|
||||||
async def get_all_tools(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> ToolListResponse:
|
|
||||||
"""
|
|
||||||
Get all available chatbot tools.
|
|
||||||
Only accessible to system administrators.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Check SYSADMIN permission
|
|
||||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Only system administrators can view tools",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get all tools from service
|
|
||||||
tools_data = await chat_service.get_all_tools(session=session)
|
|
||||||
|
|
||||||
# Convert to ToolInfo objects
|
|
||||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
|
||||||
|
|
||||||
logger.info(f"User {currentUser.id} retrieved {len(tools)} tools")
|
|
||||||
|
|
||||||
return ToolListResponse(tools=tools)
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error retrieving tools: {type(e).__name__}: {str(e)}", exc_info=True
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to retrieve tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/tools/grant", response_model=GrantToolResponse)
|
|
||||||
@limiter.limit("10/minute")
|
|
||||||
async def grant_tool_to_user(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
grant_request: GrantToolRequest,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> GrantToolResponse:
|
|
||||||
"""
|
|
||||||
Grant a tool to a user.
|
|
||||||
Only accessible to system administrators.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Check SYSADMIN permission
|
|
||||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Only system administrators can grant tools",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Grant the tool
|
|
||||||
await chat_service.grant_tool_to_user(
|
|
||||||
user_id=grant_request.user_id,
|
|
||||||
tool_id=grant_request.tool_id,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"User {currentUser.id} granted tool {grant_request.tool_id} to user {grant_request.user_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return GrantToolResponse(
|
|
||||||
message=f"Tool successfully granted to user {grant_request.user_id}",
|
|
||||||
user_id=grant_request.user_id,
|
|
||||||
tool_id=grant_request.tool_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=str(e) or "Invalid request",
|
|
||||||
)
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error granting tool: {type(e).__name__}: {str(e)}", exc_info=True
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to grant tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/tools/revoke", response_model=RevokeToolResponse)
|
|
||||||
@limiter.limit("10/minute")
|
|
||||||
async def revoke_tool_from_user(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
revoke_request: RevokeToolRequest,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> RevokeToolResponse:
|
|
||||||
"""
|
|
||||||
Revoke a tool from a user.
|
|
||||||
Only accessible to system administrators.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Check SYSADMIN permission
|
|
||||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Only system administrators can revoke tools",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Revoke the tool
|
|
||||||
await chat_service.revoke_tool_from_user(
|
|
||||||
user_id=revoke_request.user_id,
|
|
||||||
tool_id=revoke_request.tool_id,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"User {currentUser.id} revoked tool {revoke_request.tool_id} from user {revoke_request.user_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return RevokeToolResponse(
|
|
||||||
message=f"Tool successfully revoked from user {revoke_request.user_id}",
|
|
||||||
user_id=revoke_request.user_id,
|
|
||||||
tool_id=revoke_request.tool_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=str(e) or "Invalid request",
|
|
||||||
)
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error revoking tool: {type(e).__name__}: {str(e)}", exc_info=True
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to revoke tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/tools/{tool_id}", response_model=UpdateToolResponse)
|
|
||||||
@limiter.limit("10/minute")
|
|
||||||
async def update_tool(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
tool_id: str,
|
|
||||||
update_request: UpdateToolRequest,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> UpdateToolResponse:
|
|
||||||
"""
|
|
||||||
Update a tool's label and/or description.
|
|
||||||
Only accessible to system administrators.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Check SYSADMIN permission
|
|
||||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Only system administrators can update tools",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update the tool
|
|
||||||
updated_fields = await chat_service.update_tool(
|
|
||||||
tool_id=tool_id,
|
|
||||||
label=update_request.label,
|
|
||||||
description=update_request.description,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"User {currentUser.id} updated tool {tool_id}, fields: {updated_fields}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return UpdateToolResponse(
|
|
||||||
message="Tool successfully updated",
|
|
||||||
tool_id=tool_id,
|
|
||||||
updated_fields=updated_fields,
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=str(e) or "Invalid request",
|
|
||||||
)
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error updating tool: {type(e).__name__}: {str(e)}", exc_info=True
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to update tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tools/user/{user_id}", response_model=ToolListResponse)
|
|
||||||
@limiter.limit("30/minute")
|
|
||||||
async def get_tools_for_specific_user(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
user_id: str,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> ToolListResponse:
|
|
||||||
"""
|
|
||||||
Get all tools granted to a specific user.
|
|
||||||
Only accessible to system administrators.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Check SYSADMIN permission
|
|
||||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Only system administrators can view user tools",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get tools for the specified user
|
|
||||||
tools_data = await chat_service.get_tools_for_user(
|
|
||||||
user_id=user_id, session=session
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to ToolInfo objects
|
|
||||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"User {currentUser.id} retrieved {len(tools)} tools for user {user_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ToolListResponse(tools=tools)
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error retrieving tools for user {user_id}: {type(e).__name__}: {str(e)}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to retrieve tools for user: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tools/me", response_model=ToolListResponse)
|
|
||||||
@limiter.limit("30/minute")
|
|
||||||
async def get_my_tools(
|
|
||||||
*,
|
|
||||||
request: Request,
|
|
||||||
currentUser: User = Depends(getCurrentUser),
|
|
||||||
session: AsyncSession = Depends(get_async_db_session),
|
|
||||||
) -> ToolListResponse:
|
|
||||||
"""
|
|
||||||
Get all tools the current user has access to.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get tools for the current user
|
|
||||||
tools_data = await chat_service.get_tools_for_user(
|
|
||||||
user_id=currentUser.id, session=session
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to ToolInfo objects
|
|
||||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"User {currentUser.id} retrieved {len(tools)} tools for themselves"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ToolListResponse(tools=tools)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error retrieving tools for user {currentUser.id}: {type(e).__name__}: {str(e)}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to retrieve your tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
|
||||||
)
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, UTC
|
|
||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue