removed chatbot on langgraph platform
This commit is contained in:
parent
711b8bc50d
commit
f05d958213
20 changed files with 7 additions and 3922 deletions
2
app.py
2
app.py
|
|
@ -422,5 +422,3 @@ app.include_router(voiceGoogleRouter)
|
|||
from modules.routes.routeSecurityAdmin import router as adminSecurityRouter
|
||||
app.include_router(adminSecurityRouter)
|
||||
|
||||
# from modules.routes.routeChatbot import router as chatbotRouter
|
||||
# app.include_router(chatbotRouter)
|
||||
|
|
|
|||
|
|
@ -1,216 +0,0 @@
|
|||
"""Chatbot API models for request/response handling."""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
|
||||
|
||||
# Chatbot API Models
|
||||
class MessageItem(BaseModel, ModelMixin):
|
||||
"""Individual message in a thread"""
|
||||
|
||||
role: str = Field(..., description="Message role (user or assistant)")
|
||||
content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel, ModelMixin):
|
||||
"""Request model for posting a chat message"""
|
||||
|
||||
thread_id: Optional[str] = Field(
|
||||
None, description="Thread ID (creates new thread if not provided)"
|
||||
)
|
||||
message: str = Field(..., description="User message content")
|
||||
tools: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="List of tool IDs to use. If not provided, all user's tools will be used",
|
||||
)
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel, ModelMixin):
|
||||
"""Response model for posting a chat message"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
messages: List[MessageItem] = Field(..., description="All messages in thread")
|
||||
|
||||
|
||||
class ThreadSummary(BaseModel, ModelMixin):
|
||||
"""Summary of a chat thread for list view"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
thread_name: str = Field(..., description="Thread name")
|
||||
date_created: float = Field(..., description="Thread creation timestamp")
|
||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
||||
|
||||
|
||||
class ThreadListResponse(BaseModel, ModelMixin):
|
||||
"""Response model for listing all threads"""
|
||||
|
||||
threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
|
||||
|
||||
|
||||
class ThreadDetail(BaseModel, ModelMixin):
|
||||
"""Detailed view of a single thread"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
date_created: float = Field(..., description="Thread creation timestamp")
|
||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
||||
messages: List[MessageItem] = Field(
|
||||
..., description="All messages in chronological order"
|
||||
)
|
||||
|
||||
|
||||
class RenameThreadRequest(BaseModel, ModelMixin):
|
||||
"""Request model for renaming a thread"""
|
||||
|
||||
new_name: str = Field(..., description="New name for the thread")
|
||||
|
||||
|
||||
class DeleteResponse(BaseModel, ModelMixin):
|
||||
"""Response model for delete operations"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
thread_id: str = Field(..., description="Deleted thread ID")
|
||||
|
||||
|
||||
# Tool Management Models
|
||||
class ToolInfo(BaseModel, ModelMixin):
|
||||
"""Information about a chatbot tool"""
|
||||
|
||||
id: str = Field(..., description="Tool UUID")
|
||||
tool_id: str = Field(
|
||||
..., description="Tool identifier (e.g., 'shared.tavily_search')"
|
||||
)
|
||||
name: str = Field(..., description="Tool function name")
|
||||
label: str = Field(..., description="Display label for the tool")
|
||||
category: str = Field(..., description="Tool category (shared or customer)")
|
||||
description: str = Field(..., description="Tool description")
|
||||
is_active: bool = Field(..., description="Whether the tool is active")
|
||||
date_created: float = Field(..., description="Creation timestamp")
|
||||
date_updated: float = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class ToolListResponse(BaseModel, ModelMixin):
|
||||
"""Response model for listing all tools"""
|
||||
|
||||
tools: List[ToolInfo] = Field(..., description="List of available tools")
|
||||
|
||||
|
||||
class GrantToolRequest(BaseModel, ModelMixin):
|
||||
"""Request model for granting a tool to a user"""
|
||||
|
||||
user_id: str = Field(..., description="User ID to grant the tool to")
|
||||
tool_id: str = Field(..., description="Tool UUID from tools table")
|
||||
|
||||
|
||||
class GrantToolResponse(BaseModel, ModelMixin):
|
||||
"""Response model after granting a tool"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
tool_id: str = Field(..., description="Tool UUID")
|
||||
|
||||
|
||||
class RevokeToolRequest(BaseModel, ModelMixin):
|
||||
"""Request model for revoking a tool from a user"""
|
||||
|
||||
user_id: str = Field(..., description="User ID to revoke the tool from")
|
||||
tool_id: str = Field(..., description="Tool UUID from tools table")
|
||||
|
||||
|
||||
class RevokeToolResponse(BaseModel, ModelMixin):
|
||||
"""Response model after revoking a tool"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
tool_id: str = Field(..., description="Tool UUID")
|
||||
|
||||
|
||||
class UpdateToolRequest(BaseModel, ModelMixin):
|
||||
"""Request model for updating a tool's label and description"""
|
||||
|
||||
label: Optional[str] = Field(None, description="New label for the tool")
|
||||
description: Optional[str] = Field(None, description="New description for the tool")
|
||||
|
||||
|
||||
class UpdateToolResponse(BaseModel, ModelMixin):
|
||||
"""Response model after updating a tool"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
tool_id: str = Field(..., description="Tool UUID")
|
||||
updated_fields: List[str] = Field(..., description="List of updated field names")
|
||||
|
||||
|
||||
# Register model labels for internationalization
|
||||
register_model_labels(
|
||||
"MessageItem",
|
||||
{"en": "Message Item", "fr": "Élément de message"},
|
||||
{
|
||||
"role": {"en": "Role", "fr": "Rôle"},
|
||||
"content": {"en": "Content", "fr": "Contenu"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessageRequest",
|
||||
{"en": "Chat Message Request", "fr": "Demande de message de chat"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"message": {"en": "Message", "fr": "Message"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessageResponse",
|
||||
{"en": "Chat Message Response", "fr": "Réponse du message de chat"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"messages": {"en": "Messages", "fr": "Messages"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadSummary",
|
||||
{"en": "Thread Summary", "fr": "Résumé du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"thread_name": {"en": "Thread Name", "fr": "Nom du fil"},
|
||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadListResponse",
|
||||
{"en": "Thread List Response", "fr": "Réponse de liste de fils"},
|
||||
{
|
||||
"threads": {"en": "Threads", "fr": "Fils"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadDetail",
|
||||
{"en": "Thread Detail", "fr": "Détail du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
||||
"messages": {"en": "Messages", "fr": "Messages"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"RenameThreadRequest",
|
||||
{"en": "Rename Thread Request", "fr": "Demande de renommage de fil"},
|
||||
{
|
||||
"new_name": {"en": "New Name", "fr": "Nouveau nom"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"DeleteResponse",
|
||||
{"en": "Delete Response", "fr": "Réponse de suppression"},
|
||||
{
|
||||
"message": {"en": "Message", "fr": "Message"},
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
},
|
||||
)
|
||||
|
|
@ -1 +0,0 @@
|
|||
"""Contains all tools available for the chatbot to use."""
|
||||
|
|
@ -1 +0,0 @@
|
|||
"""Tools that are shared between multiple customers go here."""
|
||||
|
|
@ -1,208 +0,0 @@
|
|||
"""Althaus Database Query Tool for LangGraph.
|
||||
|
||||
This tool provides database query capabilities for the Althaus database
|
||||
via an external REST API. Only SELECT queries are allowed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Annotated
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _mock_api_call(*, sql_query: str) -> dict:
|
||||
"""Mock the external REST API call to Althaus database.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL SELECT query to execute
|
||||
|
||||
Returns:
|
||||
A dictionary containing the query results with columns and rows
|
||||
"""
|
||||
# Simulate network delay
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Mock response data based on common query patterns
|
||||
if "users" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["id", "username", "email", "created_at"],
|
||||
"rows": [
|
||||
[1, "john_doe", "john@example.com", "2024-01-15"],
|
||||
[2, "jane_smith", "jane@example.com", "2024-02-20"],
|
||||
[3, "bob_wilson", "bob@example.com", "2024-03-10"],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
elif "products" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["product_id", "name", "price", "stock"],
|
||||
"rows": [
|
||||
[101, "Widget A", 29.99, 150],
|
||||
[102, "Widget B", 39.99, 75],
|
||||
[103, "Widget C", 19.99, 200],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
elif "orders" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["order_id", "customer_id", "total", "status"],
|
||||
"rows": [
|
||||
[5001, 1, 129.99, "completed"],
|
||||
[5002, 2, 89.50, "pending"],
|
||||
[5003, 1, 199.99, "shipped"],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
else:
|
||||
# Generic response for other queries
|
||||
return {
|
||||
"columns": ["id", "value", "description"],
|
||||
"rows": [
|
||||
[1, "Sample 1", "First sample entry"],
|
||||
[2, "Sample 2", "Second sample entry"],
|
||||
],
|
||||
"row_count": 2,
|
||||
}
|
||||
|
||||
|
||||
def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
|
||||
"""Validate that the query is a SELECT statement only.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL query to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message)
|
||||
"""
|
||||
# Remove leading/trailing whitespace and convert to lowercase for checking
|
||||
normalized_query = sql_query.strip().lower()
|
||||
|
||||
# Check if query starts with SELECT
|
||||
if not normalized_query.startswith("select"):
|
||||
return False, "Query must be a SELECT statement"
|
||||
|
||||
# Check for dangerous keywords that should not be in a SELECT query
|
||||
dangerous_keywords = [
|
||||
"insert",
|
||||
"update",
|
||||
"delete",
|
||||
"drop",
|
||||
"create",
|
||||
"alter",
|
||||
"truncate",
|
||||
"grant",
|
||||
"revoke",
|
||||
"exec",
|
||||
"execute",
|
||||
]
|
||||
|
||||
for keyword in dangerous_keywords:
|
||||
# Use word boundary to match whole words only
|
||||
if re.search(rf"\b{keyword}\b", normalized_query):
|
||||
return False, f"Query contains forbidden keyword: {keyword.upper()}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
|
||||
"""Format query results into a readable string.
|
||||
|
||||
Args:
|
||||
columns: List of column names
|
||||
rows: List of row data
|
||||
row_count: Total number of rows
|
||||
|
||||
Returns:
|
||||
Formatted string representation of the results
|
||||
"""
|
||||
if row_count == 0:
|
||||
return "Query executed successfully but returned no results."
|
||||
|
||||
# Calculate column widths
|
||||
col_widths = [len(str(col)) for col in columns]
|
||||
for row in rows:
|
||||
for i, cell in enumerate(row):
|
||||
col_widths[i] = max(col_widths[i], len(str(cell)))
|
||||
|
||||
# Build header
|
||||
header_parts = []
|
||||
for col, width in zip(columns, col_widths):
|
||||
header_parts.append(str(col).ljust(width))
|
||||
header = " | ".join(header_parts)
|
||||
separator = "-" * len(header)
|
||||
|
||||
# Build rows
|
||||
row_lines = []
|
||||
for row in rows:
|
||||
row_parts = []
|
||||
for cell, width in zip(row, col_widths):
|
||||
row_parts.append(str(cell).ljust(width))
|
||||
row_lines.append(" | ".join(row_parts))
|
||||
|
||||
# Combine all parts
|
||||
result_parts = [
|
||||
f"Query returned {row_count} row(s):\n",
|
||||
header,
|
||||
separator,
|
||||
"\n".join(row_lines),
|
||||
]
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
|
||||
@tool
|
||||
async def query_althaus_database(
|
||||
sql_query: Annotated[
|
||||
str, "The SQL SELECT query to execute against the Althaus database"
|
||||
],
|
||||
) -> str:
|
||||
"""Execute a SELECT query against the Althaus database via REST API.
|
||||
|
||||
Use this tool to query data from the Althaus database. Only SELECT statements
|
||||
are allowed for security reasons. The query will be forwarded to an external
|
||||
REST API and the results will be returned in a formatted table.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")
|
||||
|
||||
Returns:
|
||||
A formatted string containing the query results with columns and rows
|
||||
"""
|
||||
try:
|
||||
# Validate the query
|
||||
is_valid, error_msg = _validate_select_query(sql_query=sql_query)
|
||||
if not is_valid:
|
||||
logger.warning(f"Invalid query attempt: {sql_query[:100]}...")
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
logger.info(f"Executing Althaus database query: {sql_query[:100]}...")
|
||||
|
||||
# Mock the external REST API call
|
||||
# In production, this would be replaced with actual REST API call:
|
||||
# response = await httpx.AsyncClient().post(
|
||||
# "https://api.althaus.example.com/query",
|
||||
# json={"query": sql_query},
|
||||
# headers={"Authorization": f"Bearer {api_key}"}
|
||||
# )
|
||||
# result = response.json()
|
||||
|
||||
result = await _mock_api_call(sql_query=sql_query)
|
||||
|
||||
# Format and return results
|
||||
formatted_output = _format_results(
|
||||
columns=result["columns"],
|
||||
rows=result["rows"],
|
||||
row_count=result["row_count"],
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Query completed successfully, returned {result['row_count']} row(s)"
|
||||
)
|
||||
return formatted_output
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in query_althaus_database tool: {str(e)}")
|
||||
return f"Error executing query: {str(e)}"
|
||||
|
|
@ -1,362 +0,0 @@
|
|||
"""Power BI Query Tool for LangGraph.
|
||||
|
||||
This tool provides DAX query capabilities for Power BI datasets
|
||||
via the Power BI REST API. Only read-only queries are allowed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import functools
|
||||
from typing import Annotated
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from msal import ConfidentialClientApplication, SerializableTokenCache
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Configuration constants - encapsulated in this file
|
||||
POWERBI_DATASET_ID = APP_CONFIG.get("VALUEON_POWERBI_DATASET_ID", "")
|
||||
POWERBI_CLIENT_ID = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_ID", "")
|
||||
POWERBI_CLIENT_SECRET = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_SECRET", "")
|
||||
POWERBI_TENANT_ID = APP_CONFIG.get("VALUEON_POWERBI_TENANT_ID", "")
|
||||
POWERBI_BASE_URL = "https://api.powerbi.com/v1.0/myorg"
|
||||
POWERBI_AUTHORITY_BASE = "https://login.microsoftonline.com"
|
||||
POWERBI_SCOPE = ["https://analysis.windows.net/powerbi/api/.default"]
|
||||
|
||||
# Limit results to prevent excessive context usage
|
||||
MAX_ROWS_LIMIT = 100
|
||||
|
||||
|
||||
def _validate_environment() -> tuple[bool, str]:
|
||||
"""Validate that all required environment variables are set.
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message)
|
||||
"""
|
||||
missing = []
|
||||
if not POWERBI_DATASET_ID:
|
||||
missing.append("POWERBI_DATASET_ID")
|
||||
if not POWERBI_CLIENT_ID:
|
||||
missing.append("POWERBI_CLIENT_ID")
|
||||
if not POWERBI_CLIENT_SECRET:
|
||||
missing.append("POWERBI_CLIENT_SECRET")
|
||||
if not POWERBI_TENANT_ID:
|
||||
missing.append("POWERBI_TENANT_ID")
|
||||
|
||||
if missing:
|
||||
return False, f"Missing required environment variables: {', '.join(missing)}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _validate_dax_query(*, dax_query: str) -> tuple[bool, str]:
|
||||
"""Validate that the query is a valid DAX query.
|
||||
|
||||
Args:
|
||||
dax_query: The DAX query to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message)
|
||||
"""
|
||||
# Remove leading/trailing whitespace
|
||||
normalized_query = dax_query.strip()
|
||||
|
||||
if not normalized_query:
|
||||
return False, "Query cannot be empty"
|
||||
|
||||
# DAX queries typically start with EVALUATE, DEFINE, or are table expressions
|
||||
# We'll be lenient and just check it's not trying to do something dangerous
|
||||
# DAX is read-only by nature, but we validate structure
|
||||
|
||||
# Check for minimum length
|
||||
if len(normalized_query) < 5:
|
||||
return False, "Query is too short to be valid"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _get_access_token_sync(
|
||||
*,
|
||||
tenant_id: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
authority_base: str = POWERBI_AUTHORITY_BASE,
|
||||
cache: SerializableTokenCache | None = None,
|
||||
) -> str:
|
||||
"""Get Power BI access token using MSAL (synchronous).
|
||||
|
||||
Args:
|
||||
tenant_id: Azure AD tenant ID
|
||||
client_id: Application client ID
|
||||
client_secret: Application client secret
|
||||
authority_base: Azure AD authority base URL
|
||||
cache: Optional token cache for reuse
|
||||
|
||||
Returns:
|
||||
Access token string
|
||||
|
||||
Raises:
|
||||
RuntimeError: If token acquisition fails
|
||||
"""
|
||||
authority = f"{authority_base}/{tenant_id}"
|
||||
|
||||
app = ConfidentialClientApplication(
|
||||
client_id=client_id,
|
||||
authority=authority,
|
||||
client_credential=client_secret,
|
||||
token_cache=cache,
|
||||
)
|
||||
|
||||
# Try cache first; fall back to client credentials
|
||||
result = app.acquire_token_silent(
|
||||
POWERBI_SCOPE, account=None
|
||||
) or app.acquire_token_for_client(scopes=POWERBI_SCOPE)
|
||||
|
||||
if "access_token" not in result:
|
||||
raise RuntimeError(
|
||||
f"MSAL token error: {result.get('error')} - {result.get('error_description')}"
|
||||
)
|
||||
|
||||
return result["access_token"]
|
||||
|
||||
|
||||
async def _get_access_token_async(
|
||||
*,
|
||||
tenant_id: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Get Power BI access token using MSAL (asynchronous).
|
||||
|
||||
Args:
|
||||
tenant_id: Azure AD tenant ID
|
||||
client_id: Application client ID
|
||||
client_secret: Application client secret
|
||||
**kwargs: Additional arguments for _get_access_token_sync
|
||||
|
||||
Returns:
|
||||
Access token string
|
||||
"""
|
||||
# Create a partial function with arguments pre-filled
|
||||
func = functools.partial(
|
||||
_get_access_token_sync,
|
||||
tenant_id=tenant_id,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
**kwargs,
|
||||
)
|
||||
# Offload the blocking MSAL HTTP call to a worker thread
|
||||
return await anyio.to_thread.run_sync(func)
|
||||
|
||||
|
||||
async def _execute_dax_query(
|
||||
*, dax_query: str, dataset_id: str, access_token: str
|
||||
) -> dict:
|
||||
"""Execute a DAX query against Power BI dataset.
|
||||
|
||||
Args:
|
||||
dax_query: The DAX query to execute
|
||||
dataset_id: Power BI dataset ID
|
||||
access_token: Access token for authentication
|
||||
|
||||
Returns:
|
||||
Dictionary containing query results
|
||||
|
||||
Raises:
|
||||
RuntimeError: If query execution fails
|
||||
"""
|
||||
url = f"{POWERBI_BASE_URL}/datasets/{dataset_id}/executeQueries"
|
||||
|
||||
body = {
|
||||
"queries": [{"query": dax_query}],
|
||||
"serializerSettings": {"includeNulls": True},
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.post(url, headers=headers, json=body)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Power BI executeQueries failed: {resp.status_code} - {resp.text}"
|
||||
)
|
||||
|
||||
payload = resp.json()
|
||||
|
||||
try:
|
||||
rows = payload["results"][0]["tables"][0]["rows"]
|
||||
except (KeyError, IndexError) as e:
|
||||
raise RuntimeError("Unexpected executeQueries response structure") from e
|
||||
|
||||
# Extract column names from the first row if available
|
||||
if rows:
|
||||
columns = list(rows[0].keys())
|
||||
else:
|
||||
columns = []
|
||||
|
||||
return {"columns": columns, "rows": rows}
|
||||
|
||||
|
||||
def _strip_table_qualifier(*, column_name: str) -> str:
|
||||
"""Strip table qualifier from column name.
|
||||
|
||||
Power BI often returns columns as 'Table[Column]'. This strips to 'Column'.
|
||||
|
||||
Args:
|
||||
column_name: The column name to process
|
||||
|
||||
Returns:
|
||||
Processed column name
|
||||
"""
|
||||
if "[" in column_name and column_name.endswith("]"):
|
||||
return column_name.split("[", 1)[1][:-1]
|
||||
return column_name
|
||||
|
||||
|
||||
def _format_results(*, columns: list[str], rows: list[dict], max_rows: int) -> str:
|
||||
"""Format query results into a readable string.
|
||||
|
||||
Args:
|
||||
columns: List of column names
|
||||
rows: List of row data (as dictionaries)
|
||||
max_rows: Maximum number of rows to display
|
||||
|
||||
Returns:
|
||||
Formatted string representation of the results
|
||||
"""
|
||||
total_rows = len(rows)
|
||||
|
||||
if total_rows == 0:
|
||||
return "Query executed successfully but returned no results."
|
||||
|
||||
# Strip table qualifiers from column names
|
||||
clean_columns = [_strip_table_qualifier(column_name=col) for col in columns]
|
||||
|
||||
# Limit rows to max_rows
|
||||
display_rows = rows[:max_rows]
|
||||
truncated = total_rows > max_rows
|
||||
|
||||
# Calculate column widths
|
||||
col_widths = [len(str(col)) for col in clean_columns]
|
||||
for row in display_rows:
|
||||
for i, col in enumerate(columns):
|
||||
value = row.get(col, "")
|
||||
col_widths[i] = max(col_widths[i], len(str(value)))
|
||||
|
||||
# Build header
|
||||
header_parts = []
|
||||
for col, width in zip(clean_columns, col_widths):
|
||||
header_parts.append(str(col).ljust(width))
|
||||
header = " | ".join(header_parts)
|
||||
separator = "-" * len(header)
|
||||
|
||||
# Build rows
|
||||
row_lines = []
|
||||
for row in display_rows:
|
||||
row_parts = []
|
||||
for col, width in zip(columns, col_widths):
|
||||
value = row.get(col, "")
|
||||
row_parts.append(str(value).ljust(width))
|
||||
row_lines.append(" | ".join(row_parts))
|
||||
|
||||
# Combine all parts
|
||||
result_parts = [
|
||||
f"Query returned {total_rows} row(s):",
|
||||
]
|
||||
|
||||
if truncated:
|
||||
result_parts.append(
|
||||
f"(Results limited to {max_rows} rows for context efficiency)\n"
|
||||
)
|
||||
else:
|
||||
result_parts.append("")
|
||||
|
||||
result_parts.extend([header, separator, "\n".join(row_lines)])
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
|
||||
@tool
|
||||
async def query_powerbi_data(
|
||||
dax_query: Annotated[str, "The DAX query to execute against the Power BI dataset"],
|
||||
) -> str:
|
||||
"""Execute a DAX query against the Power BI dataset to access warehouse inventory data.
|
||||
|
||||
This tool provides access to a Power BI table called 'data_full' which contains
|
||||
articles available in the warehouse of the user. Use DAX (Data Analysis Expressions)
|
||||
queries to retrieve and analyze this inventory data.
|
||||
|
||||
Available table:
|
||||
- 'data_full': Contains warehouse inventory articles and their details
|
||||
|
||||
Common query patterns:
|
||||
- View all data: EVALUATE 'data_full'
|
||||
- With filter: EVALUATE FILTER('data_full', [Column] = "Value")
|
||||
- Top N rows: EVALUATE TOPN(10, 'data_full', [Column], DESC)
|
||||
- Calculated: EVALUATE SUMMARIZE('data_full', [Column1], "Total", SUM([Column2]))
|
||||
|
||||
Results are limited to 100 rows maximum for efficiency.
|
||||
|
||||
Args:
|
||||
dax_query: The DAX query to execute (e.g., "EVALUATE 'data_full'")
|
||||
|
||||
Returns:
|
||||
A formatted string containing the query results with columns and rows
|
||||
"""
|
||||
try:
|
||||
# Validate environment configuration
|
||||
is_valid_env, error_msg = _validate_environment()
|
||||
if not is_valid_env:
|
||||
logger.error(f"Environment validation failed: {error_msg}")
|
||||
return f"Configuration Error: {error_msg}"
|
||||
|
||||
# Validate the query
|
||||
is_valid_query, error_msg = _validate_dax_query(dax_query=dax_query)
|
||||
if not is_valid_query:
|
||||
logger.warning(f"Invalid query attempt: {dax_query[:100]}...")
|
||||
return f"Query Validation Error: {error_msg}"
|
||||
|
||||
logger.info(f"Executing Power BI query: {dax_query[:100]}...")
|
||||
|
||||
# Get access token
|
||||
access_token = await _get_access_token_async(
|
||||
tenant_id=POWERBI_TENANT_ID,
|
||||
client_id=POWERBI_CLIENT_ID,
|
||||
client_secret=POWERBI_CLIENT_SECRET,
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
result = await _execute_dax_query(
|
||||
dax_query=dax_query,
|
||||
dataset_id=POWERBI_DATASET_ID,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# Format and return results
|
||||
formatted_output = _format_results(
|
||||
columns=result["columns"], rows=result["rows"], max_rows=MAX_ROWS_LIMIT
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Query completed successfully, returned {len(result['rows'])} row(s)"
|
||||
)
|
||||
return formatted_output
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Runtime error in query_powerbi_data tool: {str(e)}")
|
||||
return f"Error executing query: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in query_powerbi_data tool: {str(e)}")
|
||||
return f"Unexpected error: {str(e)}"
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
"""Shared tools available across all chatbot implementations."""
|
||||
|
||||
from modules.features.chatBot.chatbotTools.sharedTools.toolTavilySearch import (
|
||||
tavily_search,
|
||||
)
|
||||
|
||||
__all__ = ["tavily_search"]
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
"""Tool for sending streaming status updates to users."""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
@tool
|
||||
def send_streaming_message(message: str) -> str:
|
||||
"""Send a streaming message to the user to provide updates during processing.
|
||||
|
||||
Use this tool to send short status updates to the user while you are working
|
||||
on their request. This helps keep the user informed about what you are doing.
|
||||
|
||||
Args:
|
||||
message: A short message describing what you are currently doing.
|
||||
Examples: "Searching database for relevant information..."
|
||||
"Analyzing search results..."
|
||||
"Processing your request..."
|
||||
|
||||
Returns:
|
||||
A confirmation that the message was sent.
|
||||
"""
|
||||
# This tool doesn't actually do anything - it's just for the AI to signal
|
||||
# what it's doing to the frontend via the tool call mechanism
|
||||
return f"Status update sent: {message}"
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
"""Tavily Search Tool for LangGraph.
|
||||
|
||||
This tool provides web search capabilities using the Tavily API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from langchain_core.tools import tool
|
||||
from modules.connectors.connectorAiTavily import ConnectorWeb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool
|
||||
async def tavily_search(
|
||||
query: Annotated[str, "The search query to look up on the web"],
|
||||
) -> str:
|
||||
"""Search the web using Tavily API.
|
||||
|
||||
Use this tool to search for current information, news, or any web content.
|
||||
The tool returns relevant search results including titles and URLs.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
|
||||
Returns:
|
||||
A formatted string containing search results with titles and URLs
|
||||
"""
|
||||
try:
|
||||
# Create connector instance
|
||||
connector = await ConnectorWeb.create()
|
||||
|
||||
# Perform search with default parameters
|
||||
results = await connector._search(
|
||||
query=query,
|
||||
max_results=5,
|
||||
search_depth="basic",
|
||||
include_answer=True,
|
||||
include_raw_content=False,
|
||||
)
|
||||
|
||||
# Format results
|
||||
if not results:
|
||||
return f"No results found for query: {query}"
|
||||
|
||||
formatted_results = [f"Search results for '{query}':\n"]
|
||||
for i, result in enumerate(results, 1):
|
||||
formatted_results.append(f"{i}. {result.title}")
|
||||
formatted_results.append(f" URL: {result.url}\n")
|
||||
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in tavily_search tool: {str(e)}")
|
||||
return f"Error performing search: {str(e)}"
|
||||
|
|
@ -1 +0,0 @@
|
|||
"""Domain logic for chatbot functionality."""
|
||||
|
|
@ -1,301 +0,0 @@
|
|||
"""Chatbot domain logic with LangGraph integration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, AsyncIterator, Any
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
trim_messages,
|
||||
)
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatState(BaseModel):
|
||||
"""Represents the state of a chat session."""
|
||||
|
||||
messages: Annotated[list[BaseMessage], add_messages]
|
||||
|
||||
|
||||
def get_langchain_model(*, model_name: str) -> ChatAnthropic:
|
||||
"""Map permission model names to LangChain ChatAnthropic models.
|
||||
|
||||
Args:
|
||||
model_name: The model name from permissions (e.g., "claude_4_5")
|
||||
|
||||
Returns:
|
||||
Configured ChatAnthropic instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the model name is not supported
|
||||
"""
|
||||
# Model name mapping
|
||||
model_mapping = {
|
||||
"claude_4_5": "claude-sonnet-4-5",
|
||||
# Add more mappings as needed
|
||||
}
|
||||
|
||||
anthropic_model = model_mapping.get(model_name)
|
||||
if not anthropic_model:
|
||||
logger.warning(
|
||||
f"Unknown model name '{model_name}', defaulting to claude-4-5-sonnet"
|
||||
)
|
||||
anthropic_model = "claude-4-5-sonnet"
|
||||
|
||||
return ChatAnthropic(
|
||||
model=anthropic_model,
|
||||
api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
|
||||
temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
|
||||
max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chatbot:
|
||||
"""Represents a chatbot with LangGraph integration."""
|
||||
|
||||
model: Any
|
||||
memory: Any
|
||||
app: Any = None
|
||||
system_prompt: str = "You are a helpful assistant."
|
||||
context_window_size: int = 100000
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
*,
|
||||
model: Any,
|
||||
memory: Any,
|
||||
system_prompt: str,
|
||||
tools: list,
|
||||
context_window_size: int = 100000,
|
||||
) -> "Chatbot":
|
||||
"""Factory method to create and configure a Chatbot instance.
|
||||
|
||||
Args:
|
||||
model: The chat model to use.
|
||||
memory: The chat memory checkpointer to use.
|
||||
system_prompt: The system prompt to initialize the chatbot.
|
||||
tools: List of LangChain tools the chatbot can use.
|
||||
context_window_size: Maximum tokens for context window.
|
||||
|
||||
Returns:
|
||||
A configured Chatbot instance.
|
||||
"""
|
||||
instance = cls(
|
||||
model=model,
|
||||
memory=memory,
|
||||
system_prompt=system_prompt,
|
||||
context_window_size=context_window_size,
|
||||
)
|
||||
instance.app = instance._build_app(memory=memory, tools=tools)
|
||||
return instance
|
||||
|
||||
def _build_app(
|
||||
self, *, memory: Any, tools: list
|
||||
) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
|
||||
"""Builds the chatbot application workflow using LangGraph.
|
||||
|
||||
Args:
|
||||
memory: The chat memory checkpointer to use.
|
||||
tools: The list of tools the chatbot can use.
|
||||
|
||||
Returns:
|
||||
A compiled state graph representing the chatbot application.
|
||||
"""
|
||||
llm_with_tools = self.model.bind_tools(tools=tools)
|
||||
|
||||
def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Selects a window of messages that fit within the context window size.
|
||||
|
||||
Args:
|
||||
msgs: The list of messages to select from.
|
||||
|
||||
Returns:
|
||||
A list of messages that fit within the context window size.
|
||||
"""
|
||||
|
||||
def approx_counter(items: list[BaseMessage]) -> int:
|
||||
"""Approximate token counter for messages.
|
||||
|
||||
Args:
|
||||
items: List of messages to count tokens for.
|
||||
|
||||
Returns:
|
||||
Approximate number of tokens in the messages.
|
||||
"""
|
||||
return sum(len(getattr(m, "content", "") or "") for m in items)
|
||||
|
||||
return trim_messages(
|
||||
msgs,
|
||||
strategy="last",
|
||||
token_counter=approx_counter,
|
||||
max_tokens=self.context_window_size,
|
||||
start_on="human",
|
||||
end_on=("human", "tool"),
|
||||
include_system=True,
|
||||
)
|
||||
|
||||
def agent_node(state: ChatState) -> dict:
|
||||
"""Agent node for the chatbot workflow.
|
||||
|
||||
Args:
|
||||
state: The current chat state.
|
||||
|
||||
Returns:
|
||||
The updated chat state after processing.
|
||||
"""
|
||||
# Select the message window to fit in context (trim if needed)
|
||||
window = select_window(state.messages)
|
||||
|
||||
# Ensure the system prompt is present at the start
|
||||
if not window or not isinstance(window[0], SystemMessage):
|
||||
window = [SystemMessage(content=self.system_prompt)] + window
|
||||
|
||||
# Call the LLM with tools
|
||||
response = llm_with_tools.invoke(window)
|
||||
|
||||
# Return the new state
|
||||
return {"messages": [response]}
|
||||
|
||||
def should_continue(state: ChatState) -> str:
|
||||
"""Determines whether to continue the workflow or end it.
|
||||
|
||||
This conditional edge is called after the agent node to decide
|
||||
whether to continue to the tools node (if the last message contains
|
||||
tool calls) or to end the workflow (if no tool calls are present).
|
||||
|
||||
Args:
|
||||
state: The current chat state.
|
||||
|
||||
Returns:
|
||||
The next node to transition to ("tools" or END).
|
||||
"""
|
||||
# Get the last message
|
||||
last_message = state.messages[-1]
|
||||
|
||||
# Check if the last message contains tool calls
|
||||
# If so, continue to the tools node; otherwise, end the workflow
|
||||
return "tools" if getattr(last_message, "tool_calls", None) else END
|
||||
|
||||
# Compose the workflow
|
||||
workflow = StateGraph(ChatState)
|
||||
workflow.add_node("agent", agent_node)
|
||||
workflow.add_node("tools", ToolNode(tools=tools))
|
||||
workflow.add_edge(START, "agent")
|
||||
workflow.add_conditional_edges("agent", should_continue)
|
||||
workflow.add_edge("tools", "agent")
|
||||
return workflow.compile(checkpointer=memory)
|
||||
|
||||
async def chat(
|
||||
self, *, message: str, chat_id: str = "default"
|
||||
) -> list[BaseMessage]:
|
||||
"""Processes a chat message and returns the chat history.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
chat_id: The chat thread ID.
|
||||
|
||||
Returns:
|
||||
The list of messages in the chat history.
|
||||
"""
|
||||
# Set the right thread ID for memory
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
# Single-turn chat (non-streaming)
|
||||
result = await self.app.ainvoke(
|
||||
{"messages": [HumanMessage(content=message)]}, config=config
|
||||
)
|
||||
|
||||
# Extract and return the messages from the result
|
||||
return result["messages"]
|
||||
|
||||
async def stream_events(
|
||||
self, *, message: str, chat_id: str = "default"
|
||||
) -> AsyncIterator[dict]:
|
||||
"""Stream UI-focused events using astream_events v2.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
chat_id: Logical thread identifier; forwarded in the runnable config so
|
||||
memory and tools are scoped per thread.
|
||||
|
||||
Yields:
|
||||
dict: One of:
|
||||
- ``{"type": "status", "label": str}`` for short progress updates.
|
||||
- ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
|
||||
where ``chat_history`` only includes ``user``/``assistant`` roles.
|
||||
- ``{"type": "error", "message": str}`` if an exception occurs.
|
||||
"""
|
||||
# Thread-aware config for LangGraph/LangChain
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
def _is_root(ev: dict) -> bool:
|
||||
"""Return True if the event is from the root run (v2: empty parent_ids)."""
|
||||
return not ev.get("parent_ids")
|
||||
|
||||
try:
|
||||
async for event in self.app.astream_events(
|
||||
{"messages": [HumanMessage(content=message)]},
|
||||
config=config,
|
||||
version="v2",
|
||||
):
|
||||
etype = event.get("event")
|
||||
ename = event.get("name") or ""
|
||||
edata = event.get("data") or {}
|
||||
|
||||
# Stream human-readable progress via the special send_streaming_message tool
|
||||
if etype == "on_tool_start" and ename == "send_streaming_message":
|
||||
tool_in = edata.get("input") or {}
|
||||
msg = tool_in.get("message")
|
||||
if isinstance(msg, str) and msg.strip():
|
||||
yield {"type": "status", "label": msg.strip()}
|
||||
continue
|
||||
|
||||
# Emit the final payload when the root run finishes
|
||||
if etype == "on_chain_end" and _is_root(event):
|
||||
output_obj = edata.get("output")
|
||||
|
||||
# Extract message list from the graph's final output
|
||||
final_msgs = ChatStreamingHelper.extract_messages_from_output(
|
||||
output_obj=output_obj
|
||||
)
|
||||
|
||||
# Normalize for the frontend (only user/assistant with text content)
|
||||
chat_history_payload: list[dict] = []
|
||||
for m in final_msgs:
|
||||
if isinstance(m, BaseMessage):
|
||||
d = ChatStreamingHelper.message_to_dict(msg=m)
|
||||
elif isinstance(m, dict):
|
||||
d = ChatStreamingHelper.dict_message_to_dict(obj=m)
|
||||
else:
|
||||
continue
|
||||
if d.get("role") in ("user", "assistant") and d.get("content"):
|
||||
chat_history_payload.append(d)
|
||||
|
||||
yield {
|
||||
"type": "final",
|
||||
"response": {
|
||||
"thread": chat_id,
|
||||
"chat_history": chat_history_payload,
|
||||
},
|
||||
}
|
||||
return
|
||||
|
||||
except Exception as exc:
|
||||
# Emit a single error envelope and end the stream
|
||||
logger.error(f"Error in stream_events: {str(exc)}", exc_info=True)
|
||||
yield {"type": "error", "message": f"Error processing request: {exc}"}
|
||||
|
|
@ -1,239 +0,0 @@
|
|||
"""Streaming helper utilities for chat message processing and normalization."""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
Role = Literal["user", "assistant", "system", "tool"]
|
||||
|
||||
|
||||
class ChatStreamingHelper:
|
||||
"""Pure helper methods for streaming and message normalization.
|
||||
|
||||
This class provides static utility methods for converting between different
|
||||
message formats, extracting content, and normalizing message structures
|
||||
for streaming chat applications.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def role_from_message(*, msg: BaseMessage) -> Role:
|
||||
"""Extract the role from a BaseMessage instance.
|
||||
|
||||
Args:
|
||||
msg: The BaseMessage instance to extract the role from.
|
||||
|
||||
Returns:
|
||||
The role as a string literal: "user", "assistant", "system", or "tool".
|
||||
Defaults to "assistant" if the message type is not recognized.
|
||||
|
||||
Examples:
|
||||
>>> from langchain_core.messages import HumanMessage
|
||||
>>> msg = HumanMessage(content="Hello")
|
||||
>>> ChatStreamingHelper.role_from_message(msg=msg)
|
||||
'user'
|
||||
"""
|
||||
if isinstance(msg, HumanMessage):
|
||||
return "user"
|
||||
if isinstance(msg, AIMessage):
|
||||
return "assistant"
|
||||
if isinstance(msg, SystemMessage):
|
||||
return "system"
|
||||
if isinstance(msg, ToolMessage):
|
||||
return "tool"
|
||||
return getattr(msg, "role", "assistant")
|
||||
|
||||
@staticmethod
|
||||
def flatten_content(*, content: Any) -> str:
|
||||
"""Convert complex content structures to plain text.
|
||||
|
||||
This method handles various content formats including strings, lists of
|
||||
content parts, and dictionaries with text fields. It's designed to
|
||||
normalize content from different message sources into a consistent
|
||||
plain text format.
|
||||
|
||||
Args:
|
||||
content: The content to flatten. Can be:
|
||||
- str: Returned as-is after stripping whitespace
|
||||
- list: Each item processed and joined with newlines
|
||||
- dict: Text extracted from "text" or "content" fields
|
||||
- None: Returns empty string
|
||||
- Any other type: Converted to string
|
||||
|
||||
Returns:
|
||||
The flattened content as a plain text string with whitespace stripped.
|
||||
|
||||
Examples:
|
||||
>>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
|
||||
>>> ChatStreamingHelper.flatten_content(content=content)
|
||||
'Hello
|
||||
nworld'
|
||||
|
||||
>>> content = {"text": "Simple message"}
|
||||
>>> ChatStreamingHelper.flatten_content(content=content)
|
||||
'Simple message'
|
||||
"""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts: List[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
if "text" in part and isinstance(part["text"], str):
|
||||
parts.append(part["text"])
|
||||
elif part.get("type") == "text" and isinstance(
|
||||
part.get("text"), str
|
||||
):
|
||||
parts.append(part["text"])
|
||||
elif "content" in part and isinstance(part["content"], str):
|
||||
parts.append(part["content"])
|
||||
else:
|
||||
# Fallback for unknown dictionary structures
|
||||
val = part.get("value")
|
||||
if isinstance(val, str):
|
||||
parts.append(val)
|
||||
else:
|
||||
parts.append(str(part))
|
||||
return "\n".join(p.strip() for p in parts if p is not None)
|
||||
if isinstance(content, dict):
|
||||
if "text" in content and isinstance(content["text"], str):
|
||||
return content["text"].strip()
|
||||
if "content" in content and isinstance(content["content"], str):
|
||||
return content["content"].strip()
|
||||
return str(content).strip()
|
||||
|
||||
@staticmethod
|
||||
def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
|
||||
"""Convert a BaseMessage instance to a dictionary for streaming output.
|
||||
|
||||
This method normalizes BaseMessage instances into a consistent dictionary
|
||||
format suitable for JSON serialization and streaming to clients.
|
||||
|
||||
Args:
|
||||
msg: The BaseMessage instance to convert.
|
||||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- "role": The message role (user, assistant, system, tool)
|
||||
- "content": The flattened message content as plain text
|
||||
- "tool_calls": Tool calls if present (optional)
|
||||
- "name": Message name if present (optional)
|
||||
|
||||
Examples:
|
||||
>>> from langchain_core.messages import HumanMessage
|
||||
>>> msg = HumanMessage(content="Hello there")
|
||||
>>> result = ChatStreamingHelper.message_to_dict(msg=msg)
|
||||
>>> result["role"]
|
||||
'user'
|
||||
>>> result["content"]
|
||||
'Hello there'
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"role": ChatStreamingHelper.role_from_message(msg=msg),
|
||||
"content": ChatStreamingHelper.flatten_content(
|
||||
content=getattr(msg, "content", "")
|
||||
),
|
||||
}
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if tool_calls:
|
||||
payload["tool_calls"] = tool_calls
|
||||
name = getattr(msg, "name", None)
|
||||
if name:
|
||||
payload["name"] = name
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert a dictionary-shaped message to a normalized dictionary.
|
||||
|
||||
This method handles messages that come from serialized state and are
|
||||
represented as dictionaries rather than BaseMessage instances. It
|
||||
normalizes various dictionary formats into a consistent structure.
|
||||
|
||||
Args:
|
||||
obj: The dictionary-shaped message to convert. Expected to contain
|
||||
fields like "role", "type", "content", "text", etc.
|
||||
|
||||
Returns:
|
||||
A normalized dictionary containing:
|
||||
- "role": The message role (user, assistant, system, tool)
|
||||
- "content": The flattened message content as plain text
|
||||
- "tool_calls": Tool calls if present (optional)
|
||||
- "name": Message name if present (optional)
|
||||
|
||||
Examples:
|
||||
>>> obj = {"type": "human", "content": "Hello"}
|
||||
>>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
|
||||
>>> result["role"]
|
||||
'user'
|
||||
>>> result["content"]
|
||||
'Hello'
|
||||
"""
|
||||
role: Optional[str] = obj.get("role")
|
||||
if not role:
|
||||
# Handle alternative type field mappings
|
||||
typ = obj.get("type")
|
||||
if typ in ("human", "user"):
|
||||
role = "user"
|
||||
elif typ in ("ai", "assistant"):
|
||||
role = "assistant"
|
||||
elif typ in ("system",):
|
||||
role = "system"
|
||||
elif typ in ("tool", "function"):
|
||||
role = "tool"
|
||||
|
||||
content = obj.get("content")
|
||||
if content is None and "text" in obj:
|
||||
content = obj["text"]
|
||||
|
||||
out: Dict[str, Any] = {
|
||||
"role": role or "assistant",
|
||||
"content": ChatStreamingHelper.flatten_content(content=content),
|
||||
}
|
||||
if "tool_calls" in obj:
|
||||
out["tool_calls"] = obj["tool_calls"]
|
||||
if obj.get("name"):
|
||||
out["name"] = obj["name"]
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
|
||||
"""Extract messages from LangGraph output objects.
|
||||
|
||||
This method handles various output formats from LangGraph execution,
|
||||
extracting the messages list from different possible structures.
|
||||
|
||||
Args:
|
||||
output_obj: The output object from LangGraph execution. Can be:
|
||||
- An object with a "messages" attribute
|
||||
- A dictionary with a "messages" key
|
||||
- Any other object (returns empty list)
|
||||
|
||||
Returns:
|
||||
A list of extracted messages, or an empty list if no messages
|
||||
are found or if the output object is None.
|
||||
|
||||
Examples:
|
||||
>>> output = {"messages": [{"role": "user", "content": "Hello"}]}
|
||||
>>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
|
||||
>>> len(messages)
|
||||
1
|
||||
"""
|
||||
if output_obj is None:
|
||||
return []
|
||||
|
||||
# Try to parse dicts first
|
||||
if isinstance(output_obj, dict):
|
||||
msgs = output_obj.get("messages")
|
||||
return msgs if isinstance(msgs, list) else []
|
||||
|
||||
# Then try to get messages attribute
|
||||
msgs = getattr(output_obj, "messages", None)
|
||||
return msgs if isinstance(msgs, list) else []
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,197 +0,0 @@
|
|||
from typing import AsyncIterator
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
# Tools Table
|
||||
class Tool(Base):
|
||||
"""Available chatbot tools.
|
||||
|
||||
Stores information about all available tools that can be assigned to users.
|
||||
Each tool has a unique tool_id that corresponds to the registry tool_id.
|
||||
"""
|
||||
|
||||
__tablename__ = "tools"
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
label: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
description: Mapped[str] = mapped_column(String(1000), nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
date_created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
date_updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# User-Tool Mapping Table
|
||||
class UserToolMapping(Base):
|
||||
"""Mapping of users to their available tools.
|
||||
|
||||
Many-to-many relationship between users and tools.
|
||||
- One user can have multiple tools
|
||||
- One tool can be assigned to multiple users
|
||||
|
||||
The combination of user_id and tool_id is unique.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_tools"
|
||||
__table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
date_granted: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
date_updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# User Thread Mapping Table
|
||||
class UserThreadMapping(Base):
|
||||
"""Mapping of users to their chat threads.
|
||||
|
||||
Used to keep track of which user owns which chat thread.
|
||||
Also stores meta data like thread name.
|
||||
|
||||
1:N relationship between user and thread. A thread belongs to exactly one user.
|
||||
A user can have multiple threads.
|
||||
Thread_id is unique in the table.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_threads"
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
thread_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||
thread_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
date_created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
date_updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# Dependency that pulls the sessionmaker off app.state
|
||||
# This is set in app.py on startup in @asynccontextmanager
|
||||
# TODO: If we use SQLAlchemy in other places, we can move this to a shared module
|
||||
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
|
||||
SessionLocal: async_sessionmaker[AsyncSession] = (
|
||||
request.app.state.checkpoint_sessionmaker
|
||||
)
|
||||
async with SessionLocal() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Optional helper to init tables at startup (demo only)
|
||||
async def init_models(engine) -> None:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def sync_tools_from_registry(session: AsyncSession) -> None:
|
||||
"""Sync tools from tool registry to database.
|
||||
|
||||
This function:
|
||||
- Adds new tools from the registry to the database
|
||||
- Updates existing tools with current registry information
|
||||
- Marks tools not present in the registry as inactive
|
||||
|
||||
Should be called on application startup after database initialization.
|
||||
|
||||
Args:
|
||||
session: Active database session
|
||||
"""
|
||||
import logging
|
||||
from sqlalchemy import select
|
||||
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Syncing tools from registry to database...")
|
||||
|
||||
# Get all tools from the registry
|
||||
registry = get_registry()
|
||||
registry_tools = registry.get_all_tools()
|
||||
|
||||
# Create a set of tool_ids from the registry
|
||||
registry_tool_ids = {tool.tool_id for tool in registry_tools}
|
||||
|
||||
logger.info(f"Found {len(registry_tools)} tools in registry")
|
||||
|
||||
# Get all existing tools from the database
|
||||
result = await session.execute(select(Tool))
|
||||
db_tools = result.scalars().all()
|
||||
db_tools_by_tool_id = {tool.tool_id: tool for tool in db_tools}
|
||||
|
||||
logger.info(f"Found {len(db_tools)} tools in database")
|
||||
|
||||
# Track changes
|
||||
added_count = 0
|
||||
updated_count = 0
|
||||
deactivated_count = 0
|
||||
|
||||
# Sync tools from registry to database
|
||||
for registry_tool in registry_tools:
|
||||
if registry_tool.tool_id in db_tools_by_tool_id:
|
||||
# Tool exists - update it
|
||||
# Preserve label and description (user-editable fields)
|
||||
db_tool = db_tools_by_tool_id[registry_tool.tool_id]
|
||||
db_tool.name = registry_tool.name
|
||||
db_tool.category = registry_tool.category
|
||||
db_tool.is_active = True
|
||||
db_tool.date_updated = datetime.now(timezone.utc)
|
||||
updated_count += 1
|
||||
logger.debug(f"Updated tool: {registry_tool.tool_id}")
|
||||
else:
|
||||
# Tool doesn't exist - create it
|
||||
new_tool = Tool(
|
||||
tool_id=registry_tool.tool_id,
|
||||
name=registry_tool.name,
|
||||
label=registry_tool.tool_id, # Use tool_id as label per spec
|
||||
category=registry_tool.category,
|
||||
description=registry_tool.description or "",
|
||||
is_active=True,
|
||||
)
|
||||
session.add(new_tool)
|
||||
added_count += 1
|
||||
logger.debug(f"Added new tool: {registry_tool.tool_id}")
|
||||
|
||||
# Mark tools not in registry as inactive
|
||||
for db_tool in db_tools:
|
||||
if db_tool.tool_id not in registry_tool_ids and db_tool.is_active:
|
||||
db_tool.is_active = False
|
||||
db_tool.date_updated = datetime.now(timezone.utc)
|
||||
deactivated_count += 1
|
||||
logger.debug(f"Deactivated tool not in registry: {db_tool.tool_id}")
|
||||
|
||||
logger.info(
|
||||
f"Tool sync complete: {added_count} added, {updated_count} updated, {deactivated_count} deactivated"
|
||||
)
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
"""PostgreSQL checkpointer utilities for LangGraph memory."""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
# Fix for Windows asyncio compatibility with psycopg (backup in case app.py fix didn't apply)
|
||||
if sys.platform == 'win32':
|
||||
try:
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
except RuntimeError:
|
||||
pass # Already set
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.rows import dict_row
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global checkpointer instance
|
||||
_checkpointer_instance: Optional[AsyncPostgresSaver] = None
|
||||
_connection_pool: Optional[AsyncConnectionPool] = None
|
||||
|
||||
|
||||
async def initialize_checkpointer() -> None:
|
||||
"""Initialize the PostgreSQL checkpointer for LangGraph.
|
||||
|
||||
This should be called during application startup.
|
||||
Creates a connection pool and PostgresSaver instance.
|
||||
"""
|
||||
global _checkpointer_instance, _connection_pool
|
||||
|
||||
if _checkpointer_instance is not None:
|
||||
logger.info("Checkpointer already initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
# Get database configuration from environment
|
||||
host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost")
|
||||
database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat")
|
||||
user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev")
|
||||
password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET")
|
||||
port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432")
|
||||
|
||||
# Build connection string
|
||||
connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}"
|
||||
|
||||
# Create async connection pool
|
||||
_connection_pool = AsyncConnectionPool(
|
||||
conninfo=connection_string,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
kwargs={"autocommit": True, "row_factory": dict_row},
|
||||
)
|
||||
|
||||
# Initialize the connection pool
|
||||
await _connection_pool.open()
|
||||
|
||||
# Create AsyncPostgresSaver with the pool
|
||||
_checkpointer_instance = AsyncPostgresSaver(_connection_pool)
|
||||
|
||||
# Setup the checkpointer (creates tables if needed)
|
||||
await _checkpointer_instance.setup()
|
||||
|
||||
logger.info("PostgreSQL checkpointer initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def close_checkpointer() -> None:
|
||||
"""Close the checkpointer and connection pool.
|
||||
|
||||
This should be called during application shutdown.
|
||||
"""
|
||||
global _checkpointer_instance, _connection_pool
|
||||
|
||||
if _connection_pool is not None:
|
||||
try:
|
||||
await _connection_pool.close()
|
||||
logger.info("PostgreSQL checkpointer connection pool closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing checkpointer connection pool: {str(e)}")
|
||||
|
||||
_checkpointer_instance = None
|
||||
_connection_pool = None
|
||||
|
||||
|
||||
def get_checkpointer() -> AsyncPostgresSaver:
|
||||
"""Get the global PostgreSQL checkpointer instance.
|
||||
|
||||
Returns:
|
||||
The initialized AsyncPostgresSaver instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If checkpointer is not initialized
|
||||
"""
|
||||
if _checkpointer_instance is None:
|
||||
raise RuntimeError(
|
||||
"PostgreSQL checkpointer not initialized. "
|
||||
"Call initialize_checkpointer() during application startup."
|
||||
)
|
||||
return _checkpointer_instance
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
"""Mock permissions module for chatbot access control.
|
||||
|
||||
This module provides mock permission functions that will be replaced
|
||||
with actual database-driven permissions in the future.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
|
||||
|
||||
# TODO: Replace these mock implementations with actual database queries
|
||||
|
||||
|
||||
def get_chatbot_tools(*, user_id: str) -> list[str]:
|
||||
"""Get list of tool IDs that the chatbot can use for a given user."""
|
||||
registry = get_registry()
|
||||
return registry.list_tool_ids()
|
||||
|
||||
|
||||
def get_chatbot_model(*, user_id: str) -> str:
|
||||
"""Gets the chatbot model(s) a user is allowed to use."""
|
||||
return "claude_4_5"
|
||||
|
||||
|
||||
def get_system_prompt(*, user_id: str) -> str:
|
||||
"""Get the system prompt for a user's chatbot session.
|
||||
|
||||
This is a mock implementation that returns a generic prompt with today's date.
|
||||
In production, this will query the database for user-specific or role-specific prompts.
|
||||
|
||||
Args:
|
||||
user_id: The unique identifier of the user
|
||||
|
||||
Returns:
|
||||
The system prompt string with the current date
|
||||
"""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
return f"You're a smart assistant. Today is {current_date}"
|
||||
|
|
@ -1,305 +0,0 @@
|
|||
"""Tool registry for auto-discovering and managing chatbot tools.
|
||||
|
||||
This module provides a central registry that automatically discovers all tools
|
||||
in the chatbotTools directory structure and provides methods to query them.
|
||||
The registry is built in-memory at startup and does not require a database.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetadata:
|
||||
"""Metadata about a discovered chatbot tool.
|
||||
|
||||
Attributes:
|
||||
tool_id: Unique identifier (e.g., 'shared.tavily_search')
|
||||
name: Function name of the tool
|
||||
category: Category of the tool ('shared' or 'customer')
|
||||
description: Tool description from docstring
|
||||
tool_instance: The actual LangChain tool instance
|
||||
module_path: Full Python module path
|
||||
"""
|
||||
|
||||
tool_id: str
|
||||
name: str
|
||||
category: str
|
||||
description: str
|
||||
tool_instance: BaseTool
|
||||
module_path: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a pretty-printed string representation for logging."""
|
||||
return (
|
||||
f"ToolMetadata(\n"
|
||||
f" tool_id='{self.tool_id}',\n"
|
||||
f" name='{self.name}',\n"
|
||||
f" category='{self.category}',\n"
|
||||
f" description='{self.description}',\n"
|
||||
f" module_path='{self.module_path}'\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Central registry for all chatbot tools.
|
||||
|
||||
This class discovers and catalogs all tools decorated with @tool in the
|
||||
chatbotTools directory structure. Tools are automatically discovered at
|
||||
initialization by scanning the filesystem.
|
||||
|
||||
The registry provides methods to query tools by ID, category, or get all tools.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an empty tool registry."""
|
||||
self._tools: Dict[str, ToolMetadata] = {}
|
||||
self._initialized: bool = False
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Discover and register all tools from the chatbotTools directory.
|
||||
|
||||
This method scans both sharedTools and customerTools directories,
|
||||
imports all tool*.py modules, and extracts functions decorated with @tool.
|
||||
|
||||
This method is idempotent - calling it multiple times has no effect
|
||||
after the first initialization.
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.debug("Tool registry already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.info("Initializing tool registry...")
|
||||
|
||||
# Get base path to chatbotTools directory
|
||||
base_path = Path(__file__).parent.parent / "chatbotTools"
|
||||
|
||||
if not base_path.exists():
|
||||
logger.warning(f"chatbotTools directory not found at {base_path}")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
# Discover tools in each category
|
||||
self._discover_category(
|
||||
category_path=base_path / "sharedTools", category="shared"
|
||||
)
|
||||
self._discover_category(
|
||||
category_path=base_path / "customerTools", category="customer"
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"Tool registry initialized with {len(self._tools)} tools")
|
||||
|
||||
def _discover_category(self, *, category_path: Path, category: str) -> None:
|
||||
"""Discover all tools in a specific category directory.
|
||||
|
||||
Args:
|
||||
category_path: Path to the category directory (sharedTools or customerTools)
|
||||
category: Category name ('shared' or 'customer')
|
||||
"""
|
||||
if not category_path.exists():
|
||||
logger.warning(f"Category directory not found: {category_path}")
|
||||
return
|
||||
|
||||
logger.debug(f"Discovering tools in category: {category}")
|
||||
|
||||
# Find all tool*.py files (excluding __init__.py)
|
||||
tool_files = [
|
||||
f for f in category_path.glob("tool*.py") if f.name != "__init__.py"
|
||||
]
|
||||
|
||||
for tool_file in tool_files:
|
||||
self._import_and_register_tools(
|
||||
tool_file=tool_file, category=category, category_path=category_path
|
||||
)
|
||||
|
||||
logger.debug(f"Discovered {len(tool_files)} tool files in {category}")
|
||||
|
||||
def _import_and_register_tools(
|
||||
self, *, tool_file: Path, category: str, category_path: Path
|
||||
) -> None:
|
||||
"""Import a tool module and register all discovered tools.
|
||||
|
||||
Args:
|
||||
tool_file: Path to the tool Python file
|
||||
category: Category name ('shared' or 'customer')
|
||||
category_path: Path to the category directory
|
||||
"""
|
||||
# Construct module name
|
||||
module_name = (
|
||||
f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Find all BaseTool instances in the module
|
||||
tools_found = 0
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if isinstance(obj, BaseTool):
|
||||
self._register_tool(
|
||||
tool_instance=obj,
|
||||
name=name,
|
||||
category=category,
|
||||
module_path=module_name,
|
||||
)
|
||||
tools_found += 1
|
||||
|
||||
if tools_found == 0:
|
||||
logger.warning(f"No tools found in {module_name}")
|
||||
else:
|
||||
logger.debug(f"Loaded {tools_found} tool(s) from {module_name}")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"Import error loading tools from {module_name}: {str(e)}. "
|
||||
f"This tool will not be available."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}"
|
||||
)
|
||||
|
||||
def _register_tool(
|
||||
self, *, tool_instance: BaseTool, name: str, category: str, module_path: str
|
||||
) -> None:
|
||||
"""Register a single tool in the registry.
|
||||
|
||||
Args:
|
||||
tool_instance: The LangChain tool instance
|
||||
name: Function name of the tool
|
||||
category: Category name ('shared' or 'customer')
|
||||
module_path: Full Python module path
|
||||
"""
|
||||
tool_id = f"{category}.{name}"
|
||||
|
||||
# Check for duplicate tool IDs
|
||||
if tool_id in self._tools:
|
||||
logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting")
|
||||
|
||||
metadata = ToolMetadata(
|
||||
tool_id=tool_id,
|
||||
name=name,
|
||||
category=category,
|
||||
description=tool_instance.description or "",
|
||||
tool_instance=tool_instance,
|
||||
module_path=module_path,
|
||||
)
|
||||
|
||||
self._tools[tool_id] = metadata
|
||||
logger.debug(f"Registered tool: {tool_id}")
|
||||
|
||||
def get_all_tools(self) -> List[ToolMetadata]:
|
||||
"""Get all registered tools.
|
||||
|
||||
Returns:
|
||||
List of all tool metadata objects
|
||||
"""
|
||||
return list(self._tools.values())
|
||||
|
||||
def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]:
|
||||
"""Get a specific tool by its ID.
|
||||
|
||||
Args:
|
||||
tool_id: The tool identifier (e.g., 'shared.tavily_search')
|
||||
|
||||
Returns:
|
||||
Tool metadata if found, None otherwise
|
||||
"""
|
||||
return self._tools.get(tool_id)
|
||||
|
||||
def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]:
|
||||
"""Get all tools in a specific category.
|
||||
|
||||
Args:
|
||||
category: Category name ('shared' or 'customer')
|
||||
|
||||
Returns:
|
||||
List of tool metadata for the specified category
|
||||
"""
|
||||
return [t for t in self._tools.values() if t.category == category]
|
||||
|
||||
def list_tool_ids(self) -> List[str]:
|
||||
"""Get a list of all registered tool IDs.
|
||||
|
||||
Returns:
|
||||
List of tool ID strings
|
||||
"""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]:
|
||||
"""Get actual tool instances for a list of tool IDs.
|
||||
|
||||
This is useful for filtering tools based on user permissions.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs to retrieve
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances for the specified IDs
|
||||
"""
|
||||
instances = []
|
||||
for tool_id in tool_ids:
|
||||
metadata = self.get_tool(tool_id=tool_id)
|
||||
if metadata:
|
||||
instances.append(metadata.tool_instance)
|
||||
else:
|
||||
logger.warning(f"Tool ID not found in registry: {tool_id}")
|
||||
return instances
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the registry has been initialized.
|
||||
|
||||
Returns:
|
||||
True if initialized, False otherwise
|
||||
"""
|
||||
return self._initialized
|
||||
|
||||
|
||||
# Global registry instance
|
||||
_registry: Optional[ToolRegistry] = None
|
||||
|
||||
|
||||
def get_registry() -> ToolRegistry:
|
||||
"""Get the global tool registry instance.
|
||||
|
||||
This function ensures the registry is initialized on first access.
|
||||
Subsequent calls return the same instance.
|
||||
|
||||
Returns:
|
||||
The global ToolRegistry instance
|
||||
"""
|
||||
global _registry
|
||||
|
||||
if _registry is None:
|
||||
_registry = ToolRegistry()
|
||||
|
||||
if not _registry.is_initialized:
|
||||
_registry.initialize()
|
||||
|
||||
return _registry
|
||||
|
||||
|
||||
def reinitialize_registry() -> ToolRegistry:
|
||||
"""Force reinitialize the tool registry.
|
||||
|
||||
This is useful for testing or when tools are added dynamically.
|
||||
|
||||
Returns:
|
||||
The reinitialized ToolRegistry instance
|
||||
"""
|
||||
global _registry
|
||||
_registry = ToolRegistry()
|
||||
_registry.initialize()
|
||||
return _registry
|
||||
|
|
@ -13,14 +13,15 @@ async def start() -> None:
|
|||
from modules.features.syncDelta import mainSyncDelta
|
||||
mainSyncDelta.startSyncManager(eventUser)
|
||||
|
||||
# Feature ChatBot
|
||||
from modules.features.chatBot.mainChatBot import start as startChatBot
|
||||
await startChatBot()
|
||||
# Feature ...
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
async def stop() -> None:
|
||||
""" Stop feature triggers and background managers """
|
||||
|
||||
# Feature ChatBot
|
||||
from modules.features.chatBot.mainChatBot import stop as stopChatBot
|
||||
await stopChatBot()
|
||||
# Feature ...
|
||||
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -1,653 +0,0 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
import logging
|
||||
|
||||
from modules.datamodels.datamodelUam import User, UserPrivilege
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from modules.features.chatBot.subChatbotDatabase import get_async_db_session
|
||||
from modules.features.chatBot.mainChatBot import (
|
||||
get_or_create_thread_for_user,
|
||||
)
|
||||
from modules.datamodels.datamodelChatbot import (
|
||||
ChatMessageRequest,
|
||||
MessageItem,
|
||||
ChatMessageResponse,
|
||||
ThreadSummary,
|
||||
ThreadListResponse,
|
||||
ThreadDetail,
|
||||
RenameThreadRequest,
|
||||
DeleteResponse,
|
||||
ToolListResponse,
|
||||
ToolInfo,
|
||||
GrantToolRequest,
|
||||
GrantToolResponse,
|
||||
RevokeToolRequest,
|
||||
RevokeToolResponse,
|
||||
UpdateToolRequest,
|
||||
UpdateToolResponse,
|
||||
)
|
||||
from modules.features.chatBot import mainChatBot as chat_service
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/chatbot",
|
||||
tags=["Chatbot"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
# --- Actual endpoints for chatbot ---
|
||||
|
||||
|
||||
@router.post("/message/stream")
|
||||
@limiter.limit("30/minute")
|
||||
async def post_chat_message_stream(
|
||||
*,
|
||||
request: Request,
|
||||
message_request: ChatMessageRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Post a message to a chat thread with streaming progress updates.
|
||||
Creates a new thread if thread_id is not provided.
|
||||
|
||||
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
||||
"""
|
||||
try:
|
||||
# Validate and get tools for the request
|
||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||
user_id=currentUser.id,
|
||||
requested_tool_ids=message_request.tools,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Get or create thread using helper function
|
||||
thread_id = await get_or_create_thread_for_user(
|
||||
thread_id=message_request.thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
thread_name=message_request.message[:100],
|
||||
refresh_date_updated=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} posted streaming message to thread {thread_id}"
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
chat_service.post_message_stream(
|
||||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
tool_ids=tool_ids,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/message", response_model=ChatMessageResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def post_chat_message(
|
||||
*,
|
||||
request: Request,
|
||||
message_request: ChatMessageRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ChatMessageResponse:
|
||||
"""
|
||||
Post a message to a chat thread and get assistant response (non-streaming).
|
||||
Creates a new thread if thread_id is not provided.
|
||||
|
||||
For streaming updates, use the /message/stream endpoint instead.
|
||||
"""
|
||||
try:
|
||||
# Validate and get tools for the request
|
||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||
user_id=currentUser.id,
|
||||
requested_tool_ids=message_request.tools,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Get or create thread using helper function
|
||||
thread_id = await get_or_create_thread_for_user(
|
||||
thread_id=message_request.thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
thread_name=message_request.message[:100],
|
||||
refresh_date_updated=True,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
||||
|
||||
response = await chat_service.post_message(
|
||||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
tool_ids=tool_ids,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads", response_model=ThreadListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_all_threads(
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ThreadListResponse:
|
||||
"""
|
||||
Get all chat threads for the current user.
|
||||
"""
|
||||
try:
|
||||
# Get all threads for the current user
|
||||
threads = await chat_service.get_all_threads_for_user(
|
||||
user=currentUser, session=session
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved {len(threads)} threads")
|
||||
|
||||
return ThreadListResponse(threads=threads)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving threads: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve threads: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}", response_model=ThreadDetail)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_thread_by_id(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ThreadDetail:
|
||||
"""
|
||||
Get a specific chat thread with all its messages from LangGraph checkpointer.
|
||||
"""
|
||||
try:
|
||||
thread_detail = await chat_service.get_thread_detail_for_user(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
|
||||
return thread_detail
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found",
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/threads/{thread_id}", response_model=DeleteResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def rename_thread(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
rename_request: RenameThreadRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> DeleteResponse:
|
||||
"""
|
||||
Rename a chat thread.
|
||||
"""
|
||||
try:
|
||||
await chat_service.update_thread_name(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
new_thread_name=rename_request.new_name,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} renamed thread {thread_id} to '{rename_request.new_name}'"
|
||||
)
|
||||
|
||||
return DeleteResponse(
|
||||
message=f"Thread {thread_id} successfully renamed",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found or permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error renaming thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to rename thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/threads/{thread_id}", response_model=DeleteResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_thread(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> DeleteResponse:
|
||||
"""
|
||||
Delete a chat thread and all its associated data from both LangGraph and database.
|
||||
"""
|
||||
try:
|
||||
await chat_service.delete_thread_for_user(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} deleted thread {thread_id}")
|
||||
|
||||
return DeleteResponse(
|
||||
message=f"Thread {thread_id} successfully deleted",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found or permission denied",
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error deleting thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
# Tool Management Endpoints
|
||||
|
||||
|
||||
@router.get("/tools", response_model=ToolListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_all_tools(
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ToolListResponse:
|
||||
"""
|
||||
Get all available chatbot tools.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can view tools",
|
||||
)
|
||||
|
||||
# Get all tools from service
|
||||
tools_data = await chat_service.get_all_tools(session=session)
|
||||
|
||||
# Convert to ToolInfo objects
|
||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved {len(tools)} tools")
|
||||
|
||||
return ToolListResponse(tools=tools)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving tools: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tools/grant", response_model=GrantToolResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def grant_tool_to_user(
|
||||
*,
|
||||
request: Request,
|
||||
grant_request: GrantToolRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> GrantToolResponse:
|
||||
"""
|
||||
Grant a tool to a user.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can grant tools",
|
||||
)
|
||||
|
||||
# Grant the tool
|
||||
await chat_service.grant_tool_to_user(
|
||||
user_id=grant_request.user_id,
|
||||
tool_id=grant_request.tool_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} granted tool {grant_request.tool_id} to user {grant_request.user_id}"
|
||||
)
|
||||
|
||||
return GrantToolResponse(
|
||||
message=f"Tool successfully granted to user {grant_request.user_id}",
|
||||
user_id=grant_request.user_id,
|
||||
tool_id=grant_request.tool_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e) or "Invalid request",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error granting tool: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to grant tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/tools/revoke", response_model=RevokeToolResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def revoke_tool_from_user(
|
||||
*,
|
||||
request: Request,
|
||||
revoke_request: RevokeToolRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> RevokeToolResponse:
|
||||
"""
|
||||
Revoke a tool from a user.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can revoke tools",
|
||||
)
|
||||
|
||||
# Revoke the tool
|
||||
await chat_service.revoke_tool_from_user(
|
||||
user_id=revoke_request.user_id,
|
||||
tool_id=revoke_request.tool_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} revoked tool {revoke_request.tool_id} from user {revoke_request.user_id}"
|
||||
)
|
||||
|
||||
return RevokeToolResponse(
|
||||
message=f"Tool successfully revoked from user {revoke_request.user_id}",
|
||||
user_id=revoke_request.user_id,
|
||||
tool_id=revoke_request.tool_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e) or "Invalid request",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error revoking tool: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to revoke tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/tools/{tool_id}", response_model=UpdateToolResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_tool(
|
||||
*,
|
||||
request: Request,
|
||||
tool_id: str,
|
||||
update_request: UpdateToolRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> UpdateToolResponse:
|
||||
"""
|
||||
Update a tool's label and/or description.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can update tools",
|
||||
)
|
||||
|
||||
# Update the tool
|
||||
updated_fields = await chat_service.update_tool(
|
||||
tool_id=tool_id,
|
||||
label=update_request.label,
|
||||
description=update_request.description,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} updated tool {tool_id}, fields: {updated_fields}"
|
||||
)
|
||||
|
||||
return UpdateToolResponse(
|
||||
message="Tool successfully updated",
|
||||
tool_id=tool_id,
|
||||
updated_fields=updated_fields,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e) or "Invalid request",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating tool: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to update tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools/user/{user_id}", response_model=ToolListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_tools_for_specific_user(
|
||||
*,
|
||||
request: Request,
|
||||
user_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ToolListResponse:
|
||||
"""
|
||||
Get all tools granted to a specific user.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can view user tools",
|
||||
)
|
||||
|
||||
# Get tools for the specified user
|
||||
tools_data = await chat_service.get_tools_for_user(
|
||||
user_id=user_id, session=session
|
||||
)
|
||||
|
||||
# Convert to ToolInfo objects
|
||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} retrieved {len(tools)} tools for user {user_id}"
|
||||
)
|
||||
|
||||
return ToolListResponse(tools=tools)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving tools for user {user_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve tools for user: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools/me", response_model=ToolListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_my_tools(
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ToolListResponse:
|
||||
"""
|
||||
Get all tools the current user has access to.
|
||||
"""
|
||||
try:
|
||||
# Get tools for the current user
|
||||
tools_data = await chat_service.get_tools_for_user(
|
||||
user_id=currentUser.id, session=session
|
||||
)
|
||||
|
||||
# Convert to ToolInfo objects
|
||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} retrieved {len(tools)} tools for themselves"
|
||||
)
|
||||
|
||||
return ToolListResponse(tools=tools)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving tools for user {currentUser.id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve your tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
from typing import Dict, Any, List, Optional
|
||||
import logging
|
||||
from datetime import datetime, UTC
|
||||
import uuid
|
||||
import asyncio
|
||||
import json
|
||||
|
|
|
|||
Loading…
Reference in a new issue