Removed the chatbot from the LangGraph platform

This commit is contained in:
ValueOn AG 2025-10-20 22:01:03 +02:00
parent 711b8bc50d
commit f05d958213
20 changed files with 7 additions and 3922 deletions

2
app.py
View file

@ -422,5 +422,3 @@ app.include_router(voiceGoogleRouter)
from modules.routes.routeSecurityAdmin import router as adminSecurityRouter
app.include_router(adminSecurityRouter)
# from modules.routes.routeChatbot import router as chatbotRouter
# app.include_router(chatbotRouter)

View file

@ -1,216 +0,0 @@
"""Chatbot API models for request/response handling."""
from typing import List, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import register_model_labels, ModelMixin
# Chatbot API Models
class MessageItem(BaseModel, ModelMixin):
    """Individual message in a thread."""

    # Role is a free-form string; presumably only "user" or "assistant" are
    # used (per the description) — not enforced by validation here.
    role: str = Field(..., description="Message role (user or assistant)")
    content: str = Field(..., description="Message content")
class ChatMessageRequest(BaseModel, ModelMixin):
    """Request model for posting a chat message."""

    # Omitting thread_id signals the server to create a brand-new thread.
    thread_id: Optional[str] = Field(
        None, description="Thread ID (creates new thread if not provided)"
    )
    message: str = Field(..., description="User message content")
    # None means "use every tool granted to the user", not "use no tools".
    tools: Optional[List[str]] = Field(
        None,
        description="List of tool IDs to use. If not provided, all user's tools will be used",
    )
class ChatMessageResponse(BaseModel, ModelMixin):
    """Response model for posting a chat message."""

    thread_id: str = Field(..., description="Thread ID")
    # Returns the FULL thread history, not just the newly generated reply.
    messages: List[MessageItem] = Field(..., description="All messages in thread")
class ThreadSummary(BaseModel, ModelMixin):
    """Summary of a chat thread for list view."""

    thread_id: str = Field(..., description="Thread ID")
    thread_name: str = Field(..., description="Thread name")
    # Timestamps are floats — presumably Unix epoch seconds; confirm with the
    # service that populates them.
    date_created: float = Field(..., description="Thread creation timestamp")
    date_updated: float = Field(..., description="Thread last updated timestamp")
class ThreadListResponse(BaseModel, ModelMixin):
    """Response model for listing all threads."""

    threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
class ThreadDetail(BaseModel, ModelMixin):
    """Detailed view of a single thread."""

    # NOTE(review): unlike ThreadSummary, this model has no thread_name field —
    # confirm whether that omission is intentional.
    thread_id: str = Field(..., description="Thread ID")
    date_created: float = Field(..., description="Thread creation timestamp")
    date_updated: float = Field(..., description="Thread last updated timestamp")
    messages: List[MessageItem] = Field(
        ..., description="All messages in chronological order"
    )
class RenameThreadRequest(BaseModel, ModelMixin):
    """Request model for renaming a thread."""

    new_name: str = Field(..., description="New name for the thread")
class DeleteResponse(BaseModel, ModelMixin):
    """Response model for delete operations."""

    message: str = Field(..., description="Confirmation message")
    thread_id: str = Field(..., description="Deleted thread ID")
# Tool Management Models
class ToolInfo(BaseModel, ModelMixin):
    """Information about a chatbot tool."""

    # `id` is the database row UUID; `tool_id` is the human-readable dotted
    # identifier (e.g. 'shared.tavily_search') — two distinct keys.
    id: str = Field(..., description="Tool UUID")
    tool_id: str = Field(
        ..., description="Tool identifier (e.g., 'shared.tavily_search')"
    )
    name: str = Field(..., description="Tool function name")
    label: str = Field(..., description="Display label for the tool")
    category: str = Field(..., description="Tool category (shared or customer)")
    description: str = Field(..., description="Tool description")
    is_active: bool = Field(..., description="Whether the tool is active")
    date_created: float = Field(..., description="Creation timestamp")
    date_updated: float = Field(..., description="Last update timestamp")
class ToolListResponse(BaseModel, ModelMixin):
    """Response model for listing all tools."""

    tools: List[ToolInfo] = Field(..., description="List of available tools")
class GrantToolRequest(BaseModel, ModelMixin):
    """Request model for granting a tool to a user."""

    user_id: str = Field(..., description="User ID to grant the tool to")
    # Refers to the tools table UUID (ToolInfo.id), not the dotted tool_id.
    tool_id: str = Field(..., description="Tool UUID from tools table")
class GrantToolResponse(BaseModel, ModelMixin):
    """Response model after granting a tool."""

    message: str = Field(..., description="Confirmation message")
    user_id: str = Field(..., description="User ID")
    tool_id: str = Field(..., description="Tool UUID")
class RevokeToolRequest(BaseModel, ModelMixin):
    """Request model for revoking a tool from a user."""

    user_id: str = Field(..., description="User ID to revoke the tool from")
    # Refers to the tools table UUID (ToolInfo.id), not the dotted tool_id.
    tool_id: str = Field(..., description="Tool UUID from tools table")
class RevokeToolResponse(BaseModel, ModelMixin):
    """Response model after revoking a tool."""

    message: str = Field(..., description="Confirmation message")
    user_id: str = Field(..., description="User ID")
    tool_id: str = Field(..., description="Tool UUID")
class UpdateToolRequest(BaseModel, ModelMixin):
    """Request model for updating a tool's label and description."""

    # Both fields optional: only the supplied fields are updated (partial update).
    label: Optional[str] = Field(None, description="New label for the tool")
    description: Optional[str] = Field(None, description="New description for the tool")
class UpdateToolResponse(BaseModel, ModelMixin):
    """Response model after updating a tool."""

    message: str = Field(..., description="Confirmation message")
    tool_id: str = Field(..., description="Tool UUID")
    # Echoes which of label/description were actually changed by the request.
    updated_fields: List[str] = Field(..., description="List of updated field names")
# Register model labels for internationalization
# NOTE(review): labels are registered only for the chat/thread models; the
# tool-management models (ToolInfo, GrantToolRequest, ...) have none, and
# ChatMessageRequest's "tools" field is not labelled — confirm intentional.
register_model_labels(
    "MessageItem",
    {"en": "Message Item", "fr": "Élément de message"},
    {
        "role": {"en": "Role", "fr": "Rôle"},
        "content": {"en": "Content", "fr": "Contenu"},
    },
)
register_model_labels(
    "ChatMessageRequest",
    {"en": "Chat Message Request", "fr": "Demande de message de chat"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "message": {"en": "Message", "fr": "Message"},
    },
)
register_model_labels(
    "ChatMessageResponse",
    {"en": "Chat Message Response", "fr": "Réponse du message de chat"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "messages": {"en": "Messages", "fr": "Messages"},
    },
)
register_model_labels(
    "ThreadSummary",
    {"en": "Thread Summary", "fr": "Résumé du fil"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "thread_name": {"en": "Thread Name", "fr": "Nom du fil"},
        "date_created": {"en": "Date Created", "fr": "Date de création"},
        "date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
    },
)
register_model_labels(
    "ThreadListResponse",
    {"en": "Thread List Response", "fr": "Réponse de liste de fils"},
    {
        "threads": {"en": "Threads", "fr": "Fils"},
    },
)
register_model_labels(
    "ThreadDetail",
    {"en": "Thread Detail", "fr": "Détail du fil"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "date_created": {"en": "Date Created", "fr": "Date de création"},
        "date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
        "messages": {"en": "Messages", "fr": "Messages"},
    },
)
register_model_labels(
    "RenameThreadRequest",
    {"en": "Rename Thread Request", "fr": "Demande de renommage de fil"},
    {
        "new_name": {"en": "New Name", "fr": "Nouveau nom"},
    },
)
register_model_labels(
    "DeleteResponse",
    {"en": "Delete Response", "fr": "Réponse de suppression"},
    {
        "message": {"en": "Message", "fr": "Message"},
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
    },
)

View file

@ -1 +0,0 @@
"""Contains all tools available for the chatbot to use."""

View file

@ -1 +0,0 @@
"""Tools that are shared between multiple customers go here."""

View file

@ -1,208 +0,0 @@
"""Althaus Database Query Tool for LangGraph.
This tool provides database query capabilities for the Althaus database
via an external REST API. Only SELECT queries are allowed.
"""
import logging
import asyncio
import re
from typing import Annotated
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
async def _mock_api_call(*, sql_query: str) -> dict:
"""Mock the external REST API call to Althaus database.
Args:
sql_query: The SQL SELECT query to execute
Returns:
A dictionary containing the query results with columns and rows
"""
# Simulate network delay
await asyncio.sleep(0.5)
# Mock response data based on common query patterns
if "users" in sql_query.lower():
return {
"columns": ["id", "username", "email", "created_at"],
"rows": [
[1, "john_doe", "john@example.com", "2024-01-15"],
[2, "jane_smith", "jane@example.com", "2024-02-20"],
[3, "bob_wilson", "bob@example.com", "2024-03-10"],
],
"row_count": 3,
}
elif "products" in sql_query.lower():
return {
"columns": ["product_id", "name", "price", "stock"],
"rows": [
[101, "Widget A", 29.99, 150],
[102, "Widget B", 39.99, 75],
[103, "Widget C", 19.99, 200],
],
"row_count": 3,
}
elif "orders" in sql_query.lower():
return {
"columns": ["order_id", "customer_id", "total", "status"],
"rows": [
[5001, 1, 129.99, "completed"],
[5002, 2, 89.50, "pending"],
[5003, 1, 199.99, "shipped"],
],
"row_count": 3,
}
else:
# Generic response for other queries
return {
"columns": ["id", "value", "description"],
"rows": [
[1, "Sample 1", "First sample entry"],
[2, "Sample 2", "Second sample entry"],
],
"row_count": 2,
}
def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
"""Validate that the query is a SELECT statement only.
Args:
sql_query: The SQL query to validate
Returns:
A tuple of (is_valid, error_message)
"""
# Remove leading/trailing whitespace and convert to lowercase for checking
normalized_query = sql_query.strip().lower()
# Check if query starts with SELECT
if not normalized_query.startswith("select"):
return False, "Query must be a SELECT statement"
# Check for dangerous keywords that should not be in a SELECT query
dangerous_keywords = [
"insert",
"update",
"delete",
"drop",
"create",
"alter",
"truncate",
"grant",
"revoke",
"exec",
"execute",
]
for keyword in dangerous_keywords:
# Use word boundary to match whole words only
if re.search(rf"\b{keyword}\b", normalized_query):
return False, f"Query contains forbidden keyword: {keyword.upper()}"
return True, ""
def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
"""Format query results into a readable string.
Args:
columns: List of column names
rows: List of row data
row_count: Total number of rows
Returns:
Formatted string representation of the results
"""
if row_count == 0:
return "Query executed successfully but returned no results."
# Calculate column widths
col_widths = [len(str(col)) for col in columns]
for row in rows:
for i, cell in enumerate(row):
col_widths[i] = max(col_widths[i], len(str(cell)))
# Build header
header_parts = []
for col, width in zip(columns, col_widths):
header_parts.append(str(col).ljust(width))
header = " | ".join(header_parts)
separator = "-" * len(header)
# Build rows
row_lines = []
for row in rows:
row_parts = []
for cell, width in zip(row, col_widths):
row_parts.append(str(cell).ljust(width))
row_lines.append(" | ".join(row_parts))
# Combine all parts
result_parts = [
f"Query returned {row_count} row(s):\n",
header,
separator,
"\n".join(row_lines),
]
return "\n".join(result_parts)
@tool
async def query_althaus_database(
    sql_query: Annotated[
        str, "The SQL SELECT query to execute against the Althaus database"
    ],
) -> str:
    """Execute a SELECT query against the Althaus database via REST API.

    Use this tool to query data from the Althaus database. Only SELECT statements
    are allowed for security reasons. The query will be forwarded to an external
    REST API and the results will be returned in a formatted table.

    Args:
        sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")

    Returns:
        A formatted string containing the query results with columns and rows
    """
    try:
        # Validate the query; errors are returned as strings (not raised) so the
        # LLM receives the failure as a tool result it can react to.
        is_valid, error_msg = _validate_select_query(sql_query=sql_query)
        if not is_valid:
            logger.warning(f"Invalid query attempt: {sql_query[:100]}...")
            return f"Error: {error_msg}"
        logger.info(f"Executing Althaus database query: {sql_query[:100]}...")
        # Mock the external REST API call
        # In production, this would be replaced with actual REST API call:
        # response = await httpx.AsyncClient().post(
        #     "https://api.althaus.example.com/query",
        #     json={"query": sql_query},
        #     headers={"Authorization": f"Bearer {api_key}"}
        # )
        # result = response.json()
        result = await _mock_api_call(sql_query=sql_query)
        # Format and return results
        formatted_output = _format_results(
            columns=result["columns"],
            rows=result["rows"],
            row_count=result["row_count"],
        )
        logger.info(
            f"Query completed successfully, returned {result['row_count']} row(s)"
        )
        return formatted_output
    except Exception as e:
        # Broad catch is deliberate: a tool must never propagate an exception
        # into the agent loop; report the failure as text instead.
        logger.error(f"Error in query_althaus_database tool: {str(e)}")
        return f"Error executing query: {str(e)}"

View file

@ -1,362 +0,0 @@
"""Power BI Query Tool for LangGraph.
This tool provides DAX query capabilities for Power BI datasets
via the Power BI REST API. Only read-only queries are allowed.
"""
import logging
import asyncio
import os
import re
import functools
from typing import Annotated
import anyio
import httpx
from langchain_core.tools import tool
from msal import ConfidentialClientApplication, SerializableTokenCache
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
# Configuration constants - encapsulated in this file
POWERBI_DATASET_ID = APP_CONFIG.get("VALUEON_POWERBI_DATASET_ID", "")
POWERBI_CLIENT_ID = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_ID", "")
POWERBI_CLIENT_SECRET = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_SECRET", "")
POWERBI_TENANT_ID = APP_CONFIG.get("VALUEON_POWERBI_TENANT_ID", "")
POWERBI_BASE_URL = "https://api.powerbi.com/v1.0/myorg"
POWERBI_AUTHORITY_BASE = "https://login.microsoftonline.com"
POWERBI_SCOPE = ["https://analysis.windows.net/powerbi/api/.default"]
# Limit results to prevent excessive context usage
MAX_ROWS_LIMIT = 100
def _validate_environment() -> tuple[bool, str]:
    """Validate that all required environment variables are set.

    Returns:
        A tuple of (is_valid, error_message)
    """
    # Names reported in the error intentionally omit the VALUEON_ prefix used
    # by the actual config keys (matching the original messages).
    required = (
        (POWERBI_DATASET_ID, "POWERBI_DATASET_ID"),
        (POWERBI_CLIENT_ID, "POWERBI_CLIENT_ID"),
        (POWERBI_CLIENT_SECRET, "POWERBI_CLIENT_SECRET"),
        (POWERBI_TENANT_ID, "POWERBI_TENANT_ID"),
    )
    missing = [name for value, name in required if not value]
    if missing:
        return False, f"Missing required environment variables: {', '.join(missing)}"
    return True, ""
def _validate_dax_query(*, dax_query: str) -> tuple[bool, str]:
"""Validate that the query is a valid DAX query.
Args:
dax_query: The DAX query to validate
Returns:
A tuple of (is_valid, error_message)
"""
# Remove leading/trailing whitespace
normalized_query = dax_query.strip()
if not normalized_query:
return False, "Query cannot be empty"
# DAX queries typically start with EVALUATE, DEFINE, or are table expressions
# We'll be lenient and just check it's not trying to do something dangerous
# DAX is read-only by nature, but we validate structure
# Check for minimum length
if len(normalized_query) < 5:
return False, "Query is too short to be valid"
return True, ""
def _get_access_token_sync(
    *,
    tenant_id: str,
    client_id: str,
    client_secret: str,
    authority_base: str = POWERBI_AUTHORITY_BASE,
    cache: SerializableTokenCache | None = None,
) -> str:
    """Get Power BI access token using MSAL (synchronous).

    Args:
        tenant_id: Azure AD tenant ID
        client_id: Application client ID
        client_secret: Application client secret
        authority_base: Azure AD authority base URL
        cache: Optional token cache for reuse
    Returns:
        Access token string
    Raises:
        RuntimeError: If token acquisition fails
    """
    authority = f"{authority_base}/{tenant_id}"
    # NOTE(review): a new ConfidentialClientApplication is built on every call;
    # without a shared `cache`, acquire_token_silent will always miss and each
    # call hits Azure AD — consider reusing the app/cache across calls.
    app = ConfidentialClientApplication(
        client_id=client_id,
        authority=authority,
        client_credential=client_secret,
        token_cache=cache,
    )
    # Try cache first; fall back to client credentials
    result = app.acquire_token_silent(
        POWERBI_SCOPE, account=None
    ) or app.acquire_token_for_client(scopes=POWERBI_SCOPE)
    if "access_token" not in result:
        raise RuntimeError(
            f"MSAL token error: {result.get('error')} - {result.get('error_description')}"
        )
    return result["access_token"]
async def _get_access_token_async(
    *,
    tenant_id: str,
    client_id: str,
    client_secret: str,
    **kwargs,
) -> str:
    """Get Power BI access token using MSAL (asynchronous).

    Args:
        tenant_id: Azure AD tenant ID
        client_id: Application client ID
        client_secret: Application client secret
        **kwargs: Additional arguments for _get_access_token_sync
    Returns:
        Access token string
    """
    # Create a partial function with arguments pre-filled, since
    # anyio.to_thread.run_sync cannot forward keyword arguments itself.
    func = functools.partial(
        _get_access_token_sync,
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret=client_secret,
        **kwargs,
    )
    # Offload the blocking MSAL HTTP call to a worker thread
    return await anyio.to_thread.run_sync(func)
async def _execute_dax_query(
    *, dax_query: str, dataset_id: str, access_token: str
) -> dict:
    """Execute a DAX query against Power BI dataset.

    Args:
        dax_query: The DAX query to execute
        dataset_id: Power BI dataset ID
        access_token: Access token for authentication
    Returns:
        Dictionary containing query results, as {"columns": [...], "rows": [...]}
        where rows are dicts keyed by (table-qualified) column name.
    Raises:
        RuntimeError: If query execution fails
    """
    url = f"{POWERBI_BASE_URL}/datasets/{dataset_id}/executeQueries"
    body = {
        "queries": [{"query": dax_query}],
        "serializerSettings": {"includeNulls": True},
    }
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }
    # Generous 60s timeout: DAX queries over large datasets can be slow.
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(url, headers=headers, json=body)
    if resp.status_code != 200:
        raise RuntimeError(
            f"Power BI executeQueries failed: {resp.status_code} - {resp.text}"
        )
    payload = resp.json()
    try:
        # executeQueries nests results as results[0].tables[0].rows.
        rows = payload["results"][0]["tables"][0]["rows"]
    except (KeyError, IndexError) as e:
        raise RuntimeError("Unexpected executeQueries response structure") from e
    # Extract column names from the first row if available; an empty result
    # set therefore yields no column metadata.
    if rows:
        columns = list(rows[0].keys())
    else:
        columns = []
    return {"columns": columns, "rows": rows}
def _strip_table_qualifier(*, column_name: str) -> str:
"""Strip table qualifier from column name.
Power BI often returns columns as 'Table[Column]'. This strips to 'Column'.
Args:
column_name: The column name to process
Returns:
Processed column name
"""
if "[" in column_name and column_name.endswith("]"):
return column_name.split("[", 1)[1][:-1]
return column_name
def _format_results(*, columns: list[str], rows: list[dict], max_rows: int) -> str:
"""Format query results into a readable string.
Args:
columns: List of column names
rows: List of row data (as dictionaries)
max_rows: Maximum number of rows to display
Returns:
Formatted string representation of the results
"""
total_rows = len(rows)
if total_rows == 0:
return "Query executed successfully but returned no results."
# Strip table qualifiers from column names
clean_columns = [_strip_table_qualifier(column_name=col) for col in columns]
# Limit rows to max_rows
display_rows = rows[:max_rows]
truncated = total_rows > max_rows
# Calculate column widths
col_widths = [len(str(col)) for col in clean_columns]
for row in display_rows:
for i, col in enumerate(columns):
value = row.get(col, "")
col_widths[i] = max(col_widths[i], len(str(value)))
# Build header
header_parts = []
for col, width in zip(clean_columns, col_widths):
header_parts.append(str(col).ljust(width))
header = " | ".join(header_parts)
separator = "-" * len(header)
# Build rows
row_lines = []
for row in display_rows:
row_parts = []
for col, width in zip(columns, col_widths):
value = row.get(col, "")
row_parts.append(str(value).ljust(width))
row_lines.append(" | ".join(row_parts))
# Combine all parts
result_parts = [
f"Query returned {total_rows} row(s):",
]
if truncated:
result_parts.append(
f"(Results limited to {max_rows} rows for context efficiency)\n"
)
else:
result_parts.append("")
result_parts.extend([header, separator, "\n".join(row_lines)])
return "\n".join(result_parts)
@tool
async def query_powerbi_data(
    dax_query: Annotated[str, "The DAX query to execute against the Power BI dataset"],
) -> str:
    """Execute a DAX query against the Power BI dataset to access warehouse inventory data.

    This tool provides access to a Power BI table called 'data_full' which contains
    articles available in the warehouse of the user. Use DAX (Data Analysis Expressions)
    queries to retrieve and analyze this inventory data.

    Available table:
    - 'data_full': Contains warehouse inventory articles and their details

    Common query patterns:
    - View all data: EVALUATE 'data_full'
    - With filter: EVALUATE FILTER('data_full', [Column] = "Value")
    - Top N rows: EVALUATE TOPN(10, 'data_full', [Column], DESC)
    - Calculated: EVALUATE SUMMARIZE('data_full', [Column1], "Total", SUM([Column2]))

    Results are limited to 100 rows maximum for efficiency.

    Args:
        dax_query: The DAX query to execute (e.g., "EVALUATE 'data_full'")

    Returns:
        A formatted string containing the query results with columns and rows
    """
    try:
        # All failures are returned as strings (never raised) so the agent loop
        # receives them as tool output it can react to.
        # Validate environment configuration
        is_valid_env, error_msg = _validate_environment()
        if not is_valid_env:
            logger.error(f"Environment validation failed: {error_msg}")
            return f"Configuration Error: {error_msg}"
        # Validate the query
        is_valid_query, error_msg = _validate_dax_query(dax_query=dax_query)
        if not is_valid_query:
            logger.warning(f"Invalid query attempt: {dax_query[:100]}...")
            return f"Query Validation Error: {error_msg}"
        logger.info(f"Executing Power BI query: {dax_query[:100]}...")
        # Get access token (MSAL call offloaded to a worker thread)
        access_token = await _get_access_token_async(
            tenant_id=POWERBI_TENANT_ID,
            client_id=POWERBI_CLIENT_ID,
            client_secret=POWERBI_CLIENT_SECRET,
        )
        # Execute the query
        result = await _execute_dax_query(
            dax_query=dax_query,
            dataset_id=POWERBI_DATASET_ID,
            access_token=access_token,
        )
        # Format and return results (truncated to MAX_ROWS_LIMIT rows)
        formatted_output = _format_results(
            columns=result["columns"], rows=result["rows"], max_rows=MAX_ROWS_LIMIT
        )
        logger.info(
            f"Query completed successfully, returned {len(result['rows'])} row(s)"
        )
        return formatted_output
    except RuntimeError as e:
        # RuntimeError is the error type raised by the token/query helpers.
        logger.error(f"Runtime error in query_powerbi_data tool: {str(e)}")
        return f"Error executing query: {str(e)}"
    except Exception as e:
        logger.error(f"Unexpected error in query_powerbi_data tool: {str(e)}")
        return f"Unexpected error: {str(e)}"

View file

@ -1,7 +0,0 @@
"""Shared tools available across all chatbot implementations."""
from modules.features.chatBot.chatbotTools.sharedTools.toolTavilySearch import (
tavily_search,
)
__all__ = ["tavily_search"]

View file

@ -1,24 +0,0 @@
"""Tool for sending streaming status updates to users."""
from langchain_core.tools import tool
# NOTE(review): the docstring below is the tool description shown to the LLM
# at runtime — do not edit it casually; it is part of the prompt contract.
@tool
def send_streaming_message(message: str) -> str:
    """Send a streaming message to the user to provide updates during processing.

    Use this tool to send short status updates to the user while you are working
    on their request. This helps keep the user informed about what you are doing.

    Args:
        message: A short message describing what you are currently doing.
            Examples: "Searching database for relevant information..."
            "Analyzing search results..."
            "Processing your request..."

    Returns:
        A confirmation that the message was sent.
    """
    # This tool doesn't actually do anything - it's just for the AI to signal
    # what it's doing to the frontend via the tool call mechanism
    # (the streaming layer intercepts on_tool_start events for this tool name).
    return f"Status update sent: {message}"

View file

@ -1,55 +0,0 @@
"""Tavily Search Tool for LangGraph.
This tool provides web search capabilities using the Tavily API.
"""
import logging
from typing import Annotated
from langchain_core.tools import tool
from modules.connectors.connectorAiTavily import ConnectorWeb
logger = logging.getLogger(__name__)
@tool
async def tavily_search(
    query: Annotated[str, "The search query to look up on the web"],
) -> str:
    """Search the web using Tavily API.

    Use this tool to search for current information, news, or any web content.
    The tool returns relevant search results including titles and URLs.

    Args:
        query: The search query string

    Returns:
        A formatted string containing search results with titles and URLs
    """
    try:
        # Create connector instance
        connector = await ConnectorWeb.create()
        # NOTE(review): this calls the connector's private `_search` method —
        # consider exposing a public search API on ConnectorWeb instead.
        # include_answer=True is requested but the answer is not used below.
        results = await connector._search(
            query=query,
            max_results=5,
            search_depth="basic",
            include_answer=True,
            include_raw_content=False,
        )
        # Format results
        if not results:
            return f"No results found for query: {query}"
        formatted_results = [f"Search results for '{query}':\n"]
        for i, result in enumerate(results, 1):
            formatted_results.append(f"{i}. {result.title}")
            formatted_results.append(f"   URL: {result.url}\n")
        return "\n".join(formatted_results)
    except Exception as e:
        # Never propagate exceptions out of a tool; return the error as text.
        logger.error(f"Error in tavily_search tool: {str(e)}")
        return f"Error performing search: {str(e)}"

View file

@ -1 +0,0 @@
"""Domain logic for chatbot functionality."""

View file

@ -1,301 +0,0 @@
"""Chatbot domain logic with LangGraph integration."""
from dataclasses import dataclass
from typing import Annotated, AsyncIterator, Any
import logging
from pydantic import BaseModel
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
trim_messages,
)
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langchain_anthropic import ChatAnthropic
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
from modules.features.chatBot.utils.toolRegistry import get_registry
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
class ChatState(BaseModel):
    """Represents the state of a chat session."""

    # add_messages is LangGraph's reducer: node return values are appended to
    # (not replacing) the running message history.
    messages: Annotated[list[BaseMessage], add_messages]
def get_langchain_model(*, model_name: str) -> ChatAnthropic:
    """Map permission model names to LangChain ChatAnthropic models.

    Args:
        model_name: The model name from permissions (e.g., "claude_4_5")

    Returns:
        Configured ChatAnthropic instance. Unknown names fall back to the
        default model with a warning (no exception is raised).
    """
    # Permission-name -> Anthropic model ID mapping.
    model_mapping = {
        "claude_4_5": "claude-sonnet-4-5",
        # Add more mappings as needed
    }
    anthropic_model = model_mapping.get(model_name)
    if not anthropic_model:
        # Bug fix: the fallback previously used "claude-4-5-sonnet", which is
        # not a valid Anthropic model ID and did not match the mapping's
        # "claude-sonnet-4-5"; fall back to the same valid default ID.
        anthropic_model = "claude-sonnet-4-5"
        logger.warning(
            f"Unknown model name '{model_name}', defaulting to {anthropic_model}"
        )
    return ChatAnthropic(
        model=anthropic_model,
        api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
        temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
        max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
    )
@dataclass
class Chatbot:
    """Represents a chatbot with LangGraph integration."""

    # Typed Any to avoid tight coupling: model must expose bind_tools(), and
    # memory is expected to be a LangGraph checkpointer — TODO confirm.
    model: Any
    memory: Any
    # Compiled LangGraph app; populated by create(), None until then.
    app: Any = None
    system_prompt: str = "You are a helpful assistant."
    # Approximate token budget used when trimming history (see select_window).
    context_window_size: int = 100000

    @classmethod
    async def create(
        cls,
        *,
        model: Any,
        memory: Any,
        system_prompt: str,
        tools: list,
        context_window_size: int = 100000,
    ) -> "Chatbot":
        """Factory method to create and configure a Chatbot instance.

        Args:
            model: The chat model to use.
            memory: The chat memory checkpointer to use.
            system_prompt: The system prompt to initialize the chatbot.
            tools: List of LangChain tools the chatbot can use.
            context_window_size: Maximum tokens for context window.

        Returns:
            A configured Chatbot instance.
        """
        instance = cls(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
            context_window_size=context_window_size,
        )
        # Build the graph after construction so node closures can read
        # instance attributes (system_prompt, context_window_size).
        instance.app = instance._build_app(memory=memory, tools=tools)
        return instance

    def _build_app(
        self, *, memory: Any, tools: list
    ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
        """Builds the chatbot application workflow using LangGraph.

        Args:
            memory: The chat memory checkpointer to use.
            tools: The list of tools the chatbot can use.

        Returns:
            A compiled state graph representing the chatbot application.
        """
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
            """Selects a window of messages that fit within the context window size.

            Args:
                msgs: The list of messages to select from.

            Returns:
                A list of messages that fit within the context window size.
            """

            def approx_counter(items: list[BaseMessage]) -> int:
                """Approximate token counter for messages.

                Counts characters, not true tokens — a cheap upper-bound proxy.

                Args:
                    items: List of messages to count tokens for.

                Returns:
                    Approximate number of tokens in the messages.
                """
                return sum(len(getattr(m, "content", "") or "") for m in items)

            # Keep the most recent messages; preserve the system message and
            # avoid starting the window mid tool-call exchange.
            return trim_messages(
                msgs,
                strategy="last",
                token_counter=approx_counter,
                max_tokens=self.context_window_size,
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )

        def agent_node(state: ChatState) -> dict:
            """Agent node for the chatbot workflow.

            Args:
                state: The current chat state.

            Returns:
                The updated chat state after processing.
            """
            # Select the message window to fit in context (trim if needed)
            window = select_window(state.messages)
            # Ensure the system prompt is present at the start
            if not window or not isinstance(window[0], SystemMessage):
                window = [SystemMessage(content=self.system_prompt)] + window
            # Call the LLM with tools (synchronous invoke inside the graph node)
            response = llm_with_tools.invoke(window)
            # Return the new state (add_messages reducer appends the response)
            return {"messages": [response]}

        def should_continue(state: ChatState) -> str:
            """Determines whether to continue the workflow or end it.

            This conditional edge is called after the agent node to decide
            whether to continue to the tools node (if the last message contains
            tool calls) or to end the workflow (if no tool calls are present).

            Args:
                state: The current chat state.

            Returns:
                The next node to transition to ("tools" or END).
            """
            # Get the last message
            last_message = state.messages[-1]
            # Check if the last message contains tool calls
            # If so, continue to the tools node; otherwise, end the workflow
            return "tools" if getattr(last_message, "tool_calls", None) else END

        # Compose the workflow: agent -> (tools -> agent)* -> END
        workflow = StateGraph(ChatState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", ToolNode(tools=tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue)
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=memory)

    async def chat(
        self, *, message: str, chat_id: str = "default"
    ) -> list[BaseMessage]:
        """Processes a chat message and returns the chat history.

        Args:
            message: The user message to process.
            chat_id: The chat thread ID.

        Returns:
            The list of messages in the chat history.
        """
        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}
        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        )
        # Extract and return the messages from the result
        return result["messages"]

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
    ) -> AsyncIterator[dict]:
        """Stream UI-focused events using astream_events v2.

        Args:
            message: The user message to process.
            chat_id: Logical thread identifier; forwarded in the runnable config so
                memory and tools are scoped per thread.

        Yields:
            dict: One of:
                - ``{"type": "status", "label": str}`` for short progress updates.
                - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
                  where ``chat_history`` only includes ``user``/``assistant`` roles.
                - ``{"type": "error", "message": str}`` if an exception occurs.
        """
        # Thread-aware config for LangGraph/LangChain
        config = {"configurable": {"thread_id": chat_id}}

        def _is_root(ev: dict) -> bool:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2",
            ):
                etype = event.get("event")
                ename = event.get("name") or ""
                edata = event.get("data") or {}
                # Stream human-readable progress via the special send_streaming_message tool
                if etype == "on_tool_start" and ename == "send_streaming_message":
                    tool_in = edata.get("input") or {}
                    msg = tool_in.get("message")
                    if isinstance(msg, str) and msg.strip():
                        yield {"type": "status", "label": msg.strip()}
                    continue
                # Emit the final payload when the root run finishes
                if etype == "on_chain_end" and _is_root(event):
                    output_obj = edata.get("output")
                    # Extract message list from the graph's final output
                    final_msgs = ChatStreamingHelper.extract_messages_from_output(
                        output_obj=output_obj
                    )
                    # Normalize for the frontend (only user/assistant with text content)
                    chat_history_payload: list[dict] = []
                    for m in final_msgs:
                        if isinstance(m, BaseMessage):
                            d = ChatStreamingHelper.message_to_dict(msg=m)
                        elif isinstance(m, dict):
                            d = ChatStreamingHelper.dict_message_to_dict(obj=m)
                        else:
                            # Skip anything that is neither a message nor a dict.
                            continue
                        if d.get("role") in ("user", "assistant") and d.get("content"):
                            chat_history_payload.append(d)
                    yield {
                        "type": "final",
                        "response": {
                            "thread": chat_id,
                            "chat_history": chat_history_payload,
                        },
                    }
                    # Stop streaming once the root run has finished.
                    return
        except Exception as exc:
            # Emit a single error envelope and end the stream
            logger.error(f"Error in stream_events: {str(exc)}", exc_info=True)
            yield {"type": "error", "message": f"Error processing request: {exc}"}

View file

@ -1,239 +0,0 @@
"""Streaming helper utilities for chat message processing and normalization."""
from typing import Any, Dict, List, Literal, Mapping, Optional
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
# Normalized role labels used by the streaming payload helpers below.
Role = Literal["user", "assistant", "system", "tool"]
class ChatStreamingHelper:
    """Pure helper methods for streaming and message normalization.

    Static utilities for converting between message formats, flattening rich
    content structures into plain text, and extracting message lists from
    LangGraph output. All methods are side-effect free.
    """

    @staticmethod
    def role_from_message(*, msg: BaseMessage) -> Role:
        """Extract the role from a BaseMessage instance.

        Args:
            msg: The BaseMessage instance to extract the role from.

        Returns:
            "user", "assistant", "system", or "tool". Unrecognized message
            types fall back to the message's own ``role`` attribute when it
            has one, else "assistant".

        Examples:
            >>> from langchain_core.messages import HumanMessage
            >>> ChatStreamingHelper.role_from_message(msg=HumanMessage(content="Hello"))
            'user'
        """
        if isinstance(msg, HumanMessage):
            return "user"
        if isinstance(msg, AIMessage):
            return "assistant"
        if isinstance(msg, SystemMessage):
            return "system"
        if isinstance(msg, ToolMessage):
            return "tool"
        return getattr(msg, "role", "assistant")

    @staticmethod
    def flatten_content(*, content: Any) -> str:
        """Convert complex content structures to plain text.

        Handles strings, lists of content parts, and dictionaries with text
        fields, normalizing them into a single stripped string.

        Args:
            content: The content to flatten. Can be:
                - str: returned stripped
                - list: dict items are flattened and joined with newlines;
                  non-dict items are ignored
                - dict: text taken from a string "text" or "content" field
                - None: returns empty string
                - any other type: converted with ``str()``

        Returns:
            The flattened content as plain text with whitespace stripped.

        Examples:
            >>> parts = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
            >>> ChatStreamingHelper.flatten_content(content=parts)
            'Hello\\nworld'
            >>> ChatStreamingHelper.flatten_content(content={"text": "Simple message"})
            'Simple message'
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            parts: List[str] = []
            for part in content:
                # Non-dict parts carry no recognized text payload; skip them
                # (matches the original behavior of only handling dicts).
                if not isinstance(part, dict):
                    continue
                if isinstance(part.get("text"), str):
                    parts.append(part["text"])
                elif isinstance(part.get("content"), str):
                    parts.append(part["content"])
                else:
                    # Fallback for unknown dictionary structures
                    val = part.get("value")
                    parts.append(val if isinstance(val, str) else str(part))
            return "\n".join(p.strip() for p in parts)
        if isinstance(content, dict):
            if isinstance(content.get("text"), str):
                return content["text"].strip()
            if isinstance(content.get("content"), str):
                return content["content"].strip()
        return str(content).strip()

    @staticmethod
    def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
        """Convert a BaseMessage instance to a dictionary for streaming output.

        Args:
            msg: The BaseMessage instance to convert.

        Returns:
            A dict with "role" and flattened "content"; "tool_calls" and
            "name" are included only when present and truthy.

        Examples:
            >>> from langchain_core.messages import HumanMessage
            >>> result = ChatStreamingHelper.message_to_dict(msg=HumanMessage(content="Hello there"))
            >>> result["role"]
            'user'
            >>> result["content"]
            'Hello there'
        """
        payload: Dict[str, Any] = {
            "role": ChatStreamingHelper.role_from_message(msg=msg),
            "content": ChatStreamingHelper.flatten_content(
                content=getattr(msg, "content", "")
            ),
        }
        tool_calls = getattr(msg, "tool_calls", None)
        if tool_calls:
            payload["tool_calls"] = tool_calls
        name = getattr(msg, "name", None)
        if name:
            payload["name"] = name
        return payload

    @staticmethod
    def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
        """Convert a dictionary-shaped message to a normalized dictionary.

        Handles messages from serialized state, mapping alternative "type"
        labels (human/ai/system/tool/function) onto roles and falling back to
        the "text" field when "content" is absent.

        Args:
            obj: The dictionary-shaped message to convert.

        Returns:
            A dict with "role" (default "assistant") and flattened "content";
            "tool_calls" is copied through verbatim when the key exists, and
            "name" when truthy.

        Examples:
            >>> result = ChatStreamingHelper.dict_message_to_dict(obj={"type": "human", "content": "Hello"})
            >>> result["role"]
            'user'
            >>> result["content"]
            'Hello'
        """
        role: Optional[str] = obj.get("role")
        if not role:
            # Handle alternative type field mappings
            typ = obj.get("type")
            if typ in ("human", "user"):
                role = "user"
            elif typ in ("ai", "assistant"):
                role = "assistant"
            elif typ in ("system",):
                role = "system"
            elif typ in ("tool", "function"):
                role = "tool"
        content = obj.get("content")
        if content is None and "text" in obj:
            content = obj["text"]
        out: Dict[str, Any] = {
            "role": role or "assistant",
            "content": ChatStreamingHelper.flatten_content(content=content),
        }
        if "tool_calls" in obj:
            out["tool_calls"] = obj["tool_calls"]
        if obj.get("name"):
            out["name"] = obj["name"]
        return out

    @staticmethod
    def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
        """Extract messages from LangGraph output objects.

        Args:
            output_obj: The output object from LangGraph execution. Can be a
                mapping with a "messages" key, an object with a ``messages``
                attribute, or anything else (yields an empty list).

        Returns:
            The extracted message list, or ``[]`` when none is found or when
            ``output_obj`` is None.

        Examples:
            >>> out = {"messages": [{"role": "user", "content": "Hello"}]}
            >>> len(ChatStreamingHelper.extract_messages_from_output(output_obj=out))
            1
        """
        if output_obj is None:
            return []
        # Try to parse dicts first (takes priority over attribute lookup).
        if isinstance(output_obj, dict):
            msgs = output_obj.get("messages")
            return msgs if isinstance(msgs, list) else []
        # Then try to get a messages attribute
        msgs = getattr(output_obj, "messages", None)
        return msgs if isinstance(msgs, list) else []

File diff suppressed because it is too large Load diff

View file

@ -1,197 +0,0 @@
from typing import AsyncIterator
import uuid
from datetime import datetime, timezone
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint
class Base(DeclarativeBase):
    """Shared SQLAlchemy declarative base for all chatbot ORM models."""

    pass
# Tools Table
class Tool(Base):
    """Available chatbot tools.

    Stores information about all available tools that can be assigned to users.
    Each tool has a unique tool_id that corresponds to the registry tool_id.
    """

    __tablename__ = "tools"

    # Surrogate primary key, generated in Python.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # Registry identifier; unique across all tool rows.
    tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Display label; the registry sync seeds it with the tool_id.
    label: Mapped[str] = mapped_column(String(255), nullable=False)
    # Tool category -- presumably "shared" or "customer" per the registry; confirm.
    category: Mapped[str] = mapped_column(String(50), nullable=False)
    description: Mapped[str] = mapped_column(String(1000), nullable=False)
    # Soft-delete flag; tools absent from the registry are marked inactive.
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Timestamps are set client-side in UTC, not by the database server.
    date_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
# User-Tool Mapping Table
class UserToolMapping(Base):
    """Mapping of users to their available tools.

    Many-to-many relationship between users and tools.
    - One user can have multiple tools
    - One tool can be assigned to multiple users
    The combination of user_id and tool_id is unique.
    """

    __tablename__ = "user_tools"
    # Enforce one grant per (user, tool) pair at the database level.
    __table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),)

    # Surrogate primary key, generated in Python.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # External user identifier (string, not a FK to a local table).
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # References Tool.id -- NOTE(review): no FK constraint is declared; confirm intentional.
    tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False)
    # Soft-revoke flag for the grant.
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Timestamps are set client-side in UTC.
    date_granted: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
# User Thread Mapping Table
class UserThreadMapping(Base):
    """Mapping of users to their chat threads.

    Used to keep track of which user owns which chat thread.
    Also stores meta data like thread name.
    1:N relationship between user and thread. A thread belongs to exactly one user.
    A user can have multiple threads.
    Thread_id is unique in the table.
    """

    __tablename__ = "user_threads"

    # Surrogate primary key, generated in Python.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # External user identifier (owner of the thread).
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # LangGraph thread id; unique so a thread has exactly one owner.
    thread_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    # Human-readable thread title shown in list views.
    thread_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Timestamps are set client-side in UTC.
    date_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
# Dependency that pulls the sessionmaker off app.state
# This is set in app.py on startup in @asynccontextmanager
# TODO: If we use SQLAlchemy in other places, we can move this to a shared module
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields a per-request AsyncSession.

    The sessionmaker is installed on ``app.state.checkpoint_sessionmaker``
    during startup; the session is closed automatically when the request ends.
    """
    session_factory: async_sessionmaker[AsyncSession] = (
        request.app.state.checkpoint_sessionmaker
    )
    async with session_factory() as db_session:
        yield db_session
# Optional helper to init tables at startup (demo only)
async def init_models(engine) -> None:
    """Create all tables registered on Base.metadata (startup/demo helper)."""
    async with engine.begin() as conn:
        # run_sync bridges the synchronous create_all into the async engine.
        await conn.run_sync(Base.metadata.create_all)
async def sync_tools_from_registry(session: AsyncSession) -> None:
    """Sync tools from the in-memory tool registry into the database.

    Adds registry tools that are missing from the database, refreshes existing
    rows (name, category, active flag -- label and description stay
    user-editable), and deactivates rows whose tool no longer exists in the
    registry. Intended to run at application startup after the schema exists;
    the caller owns commit/rollback of the session.

    Args:
        session: Active database session.
    """
    import logging
    from sqlalchemy import select
    from modules.features.chatBot.utils.toolRegistry import get_registry

    logger = logging.getLogger(__name__)
    logger.info("Syncing tools from registry to database...")

    # Snapshot of the registry side.
    registry_tools = get_registry().get_all_tools()
    known_ids = {meta.tool_id for meta in registry_tools}
    logger.info(f"Found {len(registry_tools)} tools in registry")

    # Snapshot of the database side, indexed by tool_id.
    existing_rows = (await session.execute(select(Tool))).scalars().all()
    rows_by_id = {row.tool_id: row for row in existing_rows}
    logger.info(f"Found {len(existing_rows)} tools in database")

    added_count = updated_count = deactivated_count = 0

    # Registry -> database: insert new tools, refresh known ones.
    for meta in registry_tools:
        row = rows_by_id.get(meta.tool_id)
        if row is not None:
            # Refresh registry-owned fields; preserve label/description.
            row.name = meta.name
            row.category = meta.category
            row.is_active = True
            row.date_updated = datetime.now(timezone.utc)
            updated_count += 1
            logger.debug(f"Updated tool: {meta.tool_id}")
        else:
            session.add(
                Tool(
                    tool_id=meta.tool_id,
                    name=meta.name,
                    label=meta.tool_id,  # Use tool_id as label per spec
                    category=meta.category,
                    description=meta.description or "",
                    is_active=True,
                )
            )
            added_count += 1
            logger.debug(f"Added new tool: {meta.tool_id}")

    # Database rows with no registry counterpart are soft-disabled.
    for row in existing_rows:
        if row.tool_id not in known_ids and row.is_active:
            row.is_active = False
            row.date_updated = datetime.now(timezone.utc)
            deactivated_count += 1
            logger.debug(f"Deactivated tool not in registry: {row.tool_id}")

    logger.info(
        f"Tool sync complete: {added_count} added, {updated_count} updated, {deactivated_count} deactivated"
    )

View file

@ -1,106 +0,0 @@
"""PostgreSQL checkpointer utilities for LangGraph memory."""
import sys
import asyncio
import logging
from typing import Optional
# Fix for Windows asyncio compatibility with psycopg (backup in case app.py fix didn't apply)
# NOTE(review): presumably psycopg's async support requires a selector-based
# event loop on Windows (the default Proactor loop is incompatible) -- confirm
# against the psycopg documentation.
if sys.platform == 'win32':
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    except RuntimeError:
        pass  # Already set
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool
from psycopg.rows import dict_row
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
# Global checkpointer instance
# Module-level singletons: populated by initialize_checkpointer() and torn
# down by close_checkpointer(); read via get_checkpointer().
_checkpointer_instance: Optional[AsyncPostgresSaver] = None
_connection_pool: Optional[AsyncConnectionPool] = None
async def initialize_checkpointer() -> None:
    """Initialize the PostgreSQL checkpointer for LangGraph.

    This should be called during application startup. Creates an async
    connection pool and an AsyncPostgresSaver, and creates the checkpoint
    tables if needed. Idempotent: a second call is a no-op.

    Raises:
        Exception: Re-raises any failure from pool creation or saver setup
            after logging it.
    """
    global _checkpointer_instance, _connection_pool
    if _checkpointer_instance is not None:
        logger.info("Checkpointer already initialized")
        return
    try:
        from urllib.parse import quote

        # Get database configuration from environment
        host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost")
        database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat")
        user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev")
        password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET")
        port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432")
        # Build a libpq connection URI. Percent-encode the userinfo so
        # credentials containing '@', ':' or '/' cannot corrupt the URI.
        connection_string = (
            f"postgresql://{quote(str(user), safe='')}:"
            f"{quote(str(password), safe='')}@{host}:{port}/{database}"
        )
        # Create async connection pool
        _connection_pool = AsyncConnectionPool(
            conninfo=connection_string,
            min_size=2,
            max_size=10,
            kwargs={"autocommit": True, "row_factory": dict_row},
        )
        # Initialize the connection pool
        await _connection_pool.open()
        # Create AsyncPostgresSaver with the pool
        _checkpointer_instance = AsyncPostgresSaver(_connection_pool)
        # Setup the checkpointer (creates tables if needed)
        await _checkpointer_instance.setup()
        logger.info("PostgreSQL checkpointer initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}")
        raise
async def close_checkpointer() -> None:
    """Close the checkpointer and its connection pool at shutdown.

    Always resets the module-level singletons, even if closing the pool fails.
    """
    global _checkpointer_instance, _connection_pool
    pool = _connection_pool
    if pool is not None:
        try:
            await pool.close()
            logger.info("PostgreSQL checkpointer connection pool closed")
        except Exception as e:
            # Best-effort shutdown: log and continue clearing state.
            logger.error(f"Error closing checkpointer connection pool: {str(e)}")
    _checkpointer_instance = None
    _connection_pool = None
def get_checkpointer() -> AsyncPostgresSaver:
    """Return the global PostgreSQL checkpointer instance.

    Returns:
        The initialized AsyncPostgresSaver singleton.

    Raises:
        RuntimeError: If initialize_checkpointer() has not run yet.
    """
    saver = _checkpointer_instance
    if saver is not None:
        return saver
    raise RuntimeError(
        "PostgreSQL checkpointer not initialized. "
        "Call initialize_checkpointer() during application startup."
    )

View file

@ -1,39 +0,0 @@
"""Mock permissions module for chatbot access control.
This module provides mock permission functions that will be replaced
with actual database-driven permissions in the future.
"""
from datetime import datetime
from modules.features.chatBot.utils.toolRegistry import get_registry
# TODO: Replace these mock implementations with actual database queries
def get_chatbot_tools(*, user_id: str) -> list[str]:
    """Return the tool IDs the chatbot may use for the given user.

    Mock implementation: every registered tool is allowed regardless of user.
    """
    return get_registry().list_tool_ids()
def get_chatbot_model(*, user_id: str) -> str:
    """Return the chatbot model identifier a user is allowed to use.

    Mock implementation: every user gets the same model.
    """
    allowed_model = "claude_4_5"
    return allowed_model
def get_system_prompt(*, user_id: str) -> str:
    """Return the system prompt for a user's chatbot session.

    Mock implementation: a generic prompt carrying today's date. In
    production this will query the database for user- or role-specific
    prompts.

    Args:
        user_id: The unique identifier of the user.

    Returns:
        The system prompt string including the current date.
    """
    today = datetime.now().strftime("%Y-%m-%d")
    return f"You're a smart assistant. Today is {today}"

View file

@ -1,305 +0,0 @@
"""Tool registry for auto-discovering and managing chatbot tools.
This module provides a central registry that automatically discovers all tools
in the chatbotTools directory structure and provides methods to query them.
The registry is built in-memory at startup and does not require a database.
"""
import importlib
import inspect
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
from langchain_core.tools import BaseTool
logger = logging.getLogger(__name__)
@dataclass
class ToolMetadata:
    """Metadata about a discovered chatbot tool.

    Attributes:
        tool_id: Unique identifier (e.g., 'shared.tavily_search').
        name: Function name of the tool.
        category: Category of the tool ('shared' or 'customer').
        description: Tool description taken from its docstring.
        tool_instance: The actual LangChain tool instance.
        module_path: Full Python module path the tool was imported from.
    """

    tool_id: str
    name: str
    category: str
    description: str
    tool_instance: BaseTool
    module_path: str

    def __str__(self) -> str:
        """Return a pretty-printed string representation for logging."""
        field_lines = [
            f"  tool_id='{self.tool_id}',",
            f"  name='{self.name}',",
            f"  category='{self.category}',",
            f"  description='{self.description}',",
            f"  module_path='{self.module_path}'",
        ]
        return "\n".join(["ToolMetadata("] + field_lines + [")"])
class ToolRegistry:
    """Central registry for all chatbot tools.

    This class discovers and catalogs all tools decorated with @tool in the
    chatbotTools directory structure. Tools are automatically discovered at
    initialization by scanning the filesystem.

    The registry provides methods to query tools by ID, category, or get all tools.
    """

    def __init__(self) -> None:
        """Initialize an empty tool registry."""
        # tool_id ("<category>.<name>") -> ToolMetadata
        self._tools: Dict[str, ToolMetadata] = {}
        self._initialized: bool = False

    def initialize(self) -> None:
        """Discover and register all tools from the chatbotTools directory.

        This method scans both sharedTools and customerTools directories,
        imports all tool*.py modules, and extracts functions decorated with @tool.

        This method is idempotent - calling it multiple times has no effect
        after the first initialization.
        """
        if self._initialized:
            logger.debug("Tool registry already initialized, skipping")
            return
        logger.info("Initializing tool registry...")
        # Get base path to chatbotTools directory
        base_path = Path(__file__).parent.parent / "chatbotTools"
        if not base_path.exists():
            logger.warning(f"chatbotTools directory not found at {base_path}")
            # Still mark initialized so callers don't rescan a missing directory.
            self._initialized = True
            return
        # Discover tools in each category
        self._discover_category(
            category_path=base_path / "sharedTools", category="shared"
        )
        self._discover_category(
            category_path=base_path / "customerTools", category="customer"
        )
        self._initialized = True
        logger.info(f"Tool registry initialized with {len(self._tools)} tools")

    def _discover_category(self, *, category_path: Path, category: str) -> None:
        """Discover all tools in a specific category directory.

        Args:
            category_path: Path to the category directory (sharedTools or customerTools)
            category: Category name ('shared' or 'customer')
        """
        if not category_path.exists():
            logger.warning(f"Category directory not found: {category_path}")
            return
        logger.debug(f"Discovering tools in category: {category}")
        # Find all tool*.py files (excluding __init__.py)
        tool_files = [
            f for f in category_path.glob("tool*.py") if f.name != "__init__.py"
        ]
        for tool_file in tool_files:
            self._import_and_register_tools(
                tool_file=tool_file, category=category, category_path=category_path
            )
        logger.debug(f"Discovered {len(tool_files)} tool files in {category}")

    def _import_and_register_tools(
        self, *, tool_file: Path, category: str, category_path: Path
    ) -> None:
        """Import a tool module and register all discovered tools.

        Args:
            tool_file: Path to the tool Python file
            category: Category name ('shared' or 'customer')
            category_path: Path to the category directory
        """
        # Construct module name
        module_name = (
            f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}"
        )
        try:
            # Import the module
            module = importlib.import_module(module_name)
            # Find all BaseTool instances in the module
            tools_found = 0
            for name, obj in inspect.getmembers(module):
                # @tool-decorated functions become BaseTool instances at import time.
                if isinstance(obj, BaseTool):
                    self._register_tool(
                        tool_instance=obj,
                        name=name,
                        category=category,
                        module_path=module_name,
                    )
                    tools_found += 1
            if tools_found == 0:
                logger.warning(f"No tools found in {module_name}")
            else:
                logger.debug(f"Loaded {tools_found} tool(s) from {module_name}")
        except ImportError as e:
            # A broken tool module is skipped rather than failing startup.
            logger.error(
                f"Import error loading tools from {module_name}: {str(e)}. "
                f"This tool will not be available."
            )
        except Exception as e:
            logger.error(
                f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}"
            )

    def _register_tool(
        self, *, tool_instance: BaseTool, name: str, category: str, module_path: str
    ) -> None:
        """Register a single tool in the registry.

        Args:
            tool_instance: The LangChain tool instance
            name: Function name of the tool
            category: Category name ('shared' or 'customer')
            module_path: Full Python module path
        """
        tool_id = f"{category}.{name}"
        # Check for duplicate tool IDs
        if tool_id in self._tools:
            logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting")
        metadata = ToolMetadata(
            tool_id=tool_id,
            name=name,
            category=category,
            description=tool_instance.description or "",
            tool_instance=tool_instance,
            module_path=module_path,
        )
        self._tools[tool_id] = metadata
        logger.debug(f"Registered tool: {tool_id}")

    def get_all_tools(self) -> List[ToolMetadata]:
        """Get all registered tools.

        Returns:
            List of all tool metadata objects
        """
        return list(self._tools.values())

    def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]:
        """Get a specific tool by its ID.

        Args:
            tool_id: The tool identifier (e.g., 'shared.tavily_search')

        Returns:
            Tool metadata if found, None otherwise
        """
        return self._tools.get(tool_id)

    def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]:
        """Get all tools in a specific category.

        Args:
            category: Category name ('shared' or 'customer')

        Returns:
            List of tool metadata for the specified category
        """
        return [t for t in self._tools.values() if t.category == category]

    def list_tool_ids(self) -> List[str]:
        """Get a list of all registered tool IDs.

        Returns:
            List of tool ID strings
        """
        return list(self._tools.keys())

    def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]:
        """Get actual tool instances for a list of tool IDs.

        This is useful for filtering tools based on user permissions.

        Args:
            tool_ids: List of tool IDs to retrieve

        Returns:
            List of BaseTool instances for the specified IDs
        """
        instances = []
        for tool_id in tool_ids:
            metadata = self.get_tool(tool_id=tool_id)
            if metadata:
                instances.append(metadata.tool_instance)
            else:
                # Unknown IDs are dropped with a warning rather than raising.
                logger.warning(f"Tool ID not found in registry: {tool_id}")
        return instances

    @property
    def is_initialized(self) -> bool:
        """Check if the registry has been initialized.

        Returns:
            True if initialized, False otherwise
        """
        return self._initialized
# Global registry instance
# Lazily created by get_registry(); replaced wholesale by reinitialize_registry().
_registry: Optional[ToolRegistry] = None
def get_registry() -> ToolRegistry:
    """Return the global tool registry, creating and initializing it lazily.

    The first call constructs and initializes the singleton; subsequent calls
    return the same instance.

    Returns:
        The global ToolRegistry instance.
    """
    global _registry
    registry = _registry
    if registry is None:
        registry = ToolRegistry()
        _registry = registry
    if not registry.is_initialized:
        registry.initialize()
    return registry
def reinitialize_registry() -> ToolRegistry:
    """Discard the current registry and build a fresh one.

    Useful for tests or when tools are added dynamically.

    Returns:
        The newly initialized ToolRegistry instance.
    """
    global _registry
    fresh = ToolRegistry()
    # Publish the new instance first, then run discovery (matches original order).
    _registry = fresh
    fresh.initialize()
    return fresh

View file

@ -13,14 +13,15 @@ async def start() -> None:
from modules.features.syncDelta import mainSyncDelta
mainSyncDelta.startSyncManager(eventUser)
# Feature ChatBot
from modules.features.chatBot.mainChatBot import start as startChatBot
await startChatBot()
# Feature ...
return True
async def stop() -> bool:
    """Stop feature triggers and background managers.

    Returns:
        True once all feature shutdown hooks have completed.
        (Annotation fixed: the function returns True but was declared -> None.)
    """
    # Feature ChatBot
    from modules.features.chatBot.mainChatBot import stop as stopChatBot

    await stopChatBot()
    # Feature ...
    return True

View file

@ -1,653 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.requests import Request
from fastapi.responses import StreamingResponse
import logging
from modules.datamodels.datamodelUam import User, UserPrivilege
from modules.security.auth import getCurrentUser, limiter
from sqlalchemy.ext.asyncio import AsyncSession
from modules.features.chatBot.subChatbotDatabase import get_async_db_session
from modules.features.chatBot.mainChatBot import (
get_or_create_thread_for_user,
)
from modules.datamodels.datamodelChatbot import (
ChatMessageRequest,
MessageItem,
ChatMessageResponse,
ThreadSummary,
ThreadListResponse,
ThreadDetail,
RenameThreadRequest,
DeleteResponse,
ToolListResponse,
ToolInfo,
GrantToolRequest,
GrantToolResponse,
RevokeToolRequest,
RevokeToolResponse,
UpdateToolRequest,
UpdateToolResponse,
)
from modules.features.chatBot import mainChatBot as chat_service
logger = logging.getLogger(__name__)
# All chatbot endpoints are mounted under /api/chatbot.
router = APIRouter(
    prefix="/api/chatbot",
    tags=["Chatbot"],
    responses={404: {"description": "Not found"}},
)
# --- Actual endpoints for chatbot ---
@router.post("/message/stream")
@limiter.limit("30/minute")
async def post_chat_message_stream(
    *,
    request: Request,
    message_request: ChatMessageRequest,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> StreamingResponse:
    """
    Post a message to a chat thread with streaming progress updates.
    Creates a new thread if thread_id is not provided.
    Returns Server-Sent Events (SSE) stream with status updates and final response.
    """
    try:
        # Validate and get tools for the request
        tool_ids = await chat_service.validate_and_get_tools_for_request(
            user_id=currentUser.id,
            requested_tool_ids=message_request.tools,
            session=session,
        )
        # Get or create thread using helper function
        # (a new thread's name defaults to the first 100 chars of the message)
        thread_id = await get_or_create_thread_for_user(
            thread_id=message_request.thread_id,
            user=currentUser,
            session=session,
            thread_name=message_request.message[:100],
            refresh_date_updated=True,
        )
        logger.info(
            f"User {currentUser.id} posted streaming message to thread {thread_id}"
        )
        # SSE response; headers disable caching and keep the connection open.
        return StreamingResponse(
            chat_service.post_message_stream(
                thread_id=thread_id,
                message=message_request.message,
                user=currentUser,
                tool_ids=tool_ids,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
    except PermissionError as e:
        logger.error(f"Permission error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e) or "Permission denied",
        )
    except ValueError as e:
        # NOTE(review): ValueError is mapped to 403 with a "Permission denied"
        # fallback; 400/422 may be more accurate -- confirm with API consumers.
        logger.error(f"Validation error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e) or "Permission denied",
        )
    except Exception as e:
        logger.error(
            f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
        )
@router.post("/message", response_model=ChatMessageResponse)
@limiter.limit("30/minute")
async def post_chat_message(
    *,
    request: Request,
    message_request: ChatMessageRequest,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> ChatMessageResponse:
    """
    Post a message to a chat thread and get assistant response (non-streaming).
    Creates a new thread if thread_id is not provided.
    For streaming updates, use the /message/stream endpoint instead.
    """
    try:
        # Validate and get tools for the request
        tool_ids = await chat_service.validate_and_get_tools_for_request(
            user_id=currentUser.id,
            requested_tool_ids=message_request.tools,
            session=session,
        )
        # Get or create thread using helper function
        # (a new thread's name defaults to the first 100 chars of the message)
        thread_id = await get_or_create_thread_for_user(
            thread_id=message_request.thread_id,
            user=currentUser,
            session=session,
            thread_name=message_request.message[:100],
            refresh_date_updated=True,
        )
        logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
        response = await chat_service.post_message(
            thread_id=thread_id,
            message=message_request.message,
            user=currentUser,
            tool_ids=tool_ids,
        )
        return response
    except PermissionError as e:
        logger.error(f"Permission error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e) or "Permission denied",
        )
    except ValueError as e:
        # NOTE(review): ValueError is mapped to 403 with a "Permission denied"
        # fallback; 400/422 may be more accurate -- confirm with API consumers.
        logger.error(f"Validation error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e) or "Permission denied",
        )
    except Exception as e:
        logger.error(
            f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
        )
@router.get("/threads", response_model=ThreadListResponse)
@limiter.limit("30/minute")
async def get_all_threads(
*,
request: Request,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> ThreadListResponse:
"""
Get all chat threads for the current user.
"""
try:
# Get all threads for the current user
threads = await chat_service.get_all_threads_for_user(
user=currentUser, session=session
)
logger.info(f"User {currentUser.id} retrieved {len(threads)} threads")
return ThreadListResponse(threads=threads)
except Exception as e:
logger.error(
f"Error retrieving threads: {type(e).__name__}: {str(e)}", exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve threads: {type(e).__name__}: {str(e) or 'No error message provided'}",
)
@router.get("/threads/{thread_id}", response_model=ThreadDetail)
@limiter.limit("30/minute")
async def get_thread_by_id(
*,
request: Request,
thread_id: str,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> ThreadDetail:
"""
Get a specific chat thread with all its messages from LangGraph checkpointer.
"""
try:
thread_detail = await chat_service.get_thread_detail_for_user(
thread_id=thread_id,
user=currentUser,
session=session,
)
logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
return thread_detail
except ValueError as e:
logger.error(f"Thread not found: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e) or "Thread not found",
)
except PermissionError as e:
logger.error(f"Permission denied: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e) or "Permission denied",
)
except Exception as e:
logger.error(
f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
)
@router.patch("/threads/{thread_id}", response_model=DeleteResponse)
@limiter.limit("30/minute")
async def rename_thread(
*,
request: Request,
thread_id: str,
rename_request: RenameThreadRequest,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> DeleteResponse:
"""
Rename a chat thread.
"""
try:
await chat_service.update_thread_name(
thread_id=thread_id,
user=currentUser,
new_thread_name=rename_request.new_name,
session=session,
)
logger.info(
f"User {currentUser.id} renamed thread {thread_id} to '{rename_request.new_name}'"
)
return DeleteResponse(
message=f"Thread {thread_id} successfully renamed",
thread_id=thread_id,
)
except ValueError as e:
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e) or "Thread not found or permission denied",
)
except Exception as e:
logger.error(
f"Error renaming thread {thread_id}: {type(e).__name__}: {str(e)}",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to rename thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
)
@router.delete("/threads/{thread_id}", response_model=DeleteResponse)
@limiter.limit("10/minute")
async def delete_thread(
*,
request: Request,
thread_id: str,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> DeleteResponse:
"""
Delete a chat thread and all its associated data from both LangGraph and database.
"""
try:
await chat_service.delete_thread_for_user(
thread_id=thread_id,
user=currentUser,
session=session,
)
logger.info(f"User {currentUser.id} deleted thread {thread_id}")
return DeleteResponse(
message=f"Thread {thread_id} successfully deleted",
thread_id=thread_id,
)
except ValueError as e:
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e) or "Thread not found or permission denied",
)
except PermissionError as e:
logger.error(f"Permission denied: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e) or "Permission denied",
)
except Exception as e:
logger.error(
f"Error deleting thread {thread_id}: {type(e).__name__}: {str(e)}",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
)
# Tool Management Endpoints
@router.get("/tools", response_model=ToolListResponse)
@limiter.limit("30/minute")
async def get_all_tools(
*,
request: Request,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> ToolListResponse:
"""
Get all available chatbot tools.
Only accessible to system administrators.
"""
try:
# Check SYSADMIN permission
if currentUser.privilege != UserPrivilege.SYSADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only system administrators can view tools",
)
# Get all tools from service
tools_data = await chat_service.get_all_tools(session=session)
# Convert to ToolInfo objects
tools = [ToolInfo(**tool) for tool in tools_data]
logger.info(f"User {currentUser.id} retrieved {len(tools)} tools")
return ToolListResponse(tools=tools)
except HTTPException:
raise
except Exception as e:
logger.error(
f"Error retrieving tools: {type(e).__name__}: {str(e)}", exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
)
@router.post("/tools/grant", response_model=GrantToolResponse)
@limiter.limit("10/minute")
async def grant_tool_to_user(
*,
request: Request,
grant_request: GrantToolRequest,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> GrantToolResponse:
"""
Grant a tool to a user.
Only accessible to system administrators.
"""
try:
# Check SYSADMIN permission
if currentUser.privilege != UserPrivilege.SYSADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only system administrators can grant tools",
)
# Grant the tool
await chat_service.grant_tool_to_user(
user_id=grant_request.user_id,
tool_id=grant_request.tool_id,
session=session,
)
logger.info(
f"User {currentUser.id} granted tool {grant_request.tool_id} to user {grant_request.user_id}"
)
return GrantToolResponse(
message=f"Tool successfully granted to user {grant_request.user_id}",
user_id=grant_request.user_id,
tool_id=grant_request.tool_id,
)
except ValueError as e:
logger.error(f"Validation error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e) or "Invalid request",
)
except HTTPException:
raise
except Exception as e:
logger.error(
f"Error granting tool: {type(e).__name__}: {str(e)}", exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to grant tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
)
@router.delete("/tools/revoke", response_model=RevokeToolResponse)
@limiter.limit("10/minute")
async def revoke_tool_from_user(
*,
request: Request,
revoke_request: RevokeToolRequest,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> RevokeToolResponse:
"""
Revoke a tool from a user.
Only accessible to system administrators.
"""
try:
# Check SYSADMIN permission
if currentUser.privilege != UserPrivilege.SYSADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only system administrators can revoke tools",
)
# Revoke the tool
await chat_service.revoke_tool_from_user(
user_id=revoke_request.user_id,
tool_id=revoke_request.tool_id,
session=session,
)
logger.info(
f"User {currentUser.id} revoked tool {revoke_request.tool_id} from user {revoke_request.user_id}"
)
return RevokeToolResponse(
message=f"Tool successfully revoked from user {revoke_request.user_id}",
user_id=revoke_request.user_id,
tool_id=revoke_request.tool_id,
)
except ValueError as e:
logger.error(f"Validation error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e) or "Invalid request",
)
except HTTPException:
raise
except Exception as e:
logger.error(
f"Error revoking tool: {type(e).__name__}: {str(e)}", exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to revoke tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
)
@router.patch("/tools/{tool_id}", response_model=UpdateToolResponse)
@limiter.limit("10/minute")
async def update_tool(
*,
request: Request,
tool_id: str,
update_request: UpdateToolRequest,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> UpdateToolResponse:
"""
Update a tool's label and/or description.
Only accessible to system administrators.
"""
try:
# Check SYSADMIN permission
if currentUser.privilege != UserPrivilege.SYSADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only system administrators can update tools",
)
# Update the tool
updated_fields = await chat_service.update_tool(
tool_id=tool_id,
label=update_request.label,
description=update_request.description,
session=session,
)
logger.info(
f"User {currentUser.id} updated tool {tool_id}, fields: {updated_fields}"
)
return UpdateToolResponse(
message="Tool successfully updated",
tool_id=tool_id,
updated_fields=updated_fields,
)
except ValueError as e:
logger.error(f"Validation error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e) or "Invalid request",
)
except HTTPException:
raise
except Exception as e:
logger.error(
f"Error updating tool: {type(e).__name__}: {str(e)}", exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
)
@router.get("/tools/user/{user_id}", response_model=ToolListResponse)
@limiter.limit("30/minute")
async def get_tools_for_specific_user(
*,
request: Request,
user_id: str,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> ToolListResponse:
"""
Get all tools granted to a specific user.
Only accessible to system administrators.
"""
try:
# Check SYSADMIN permission
if currentUser.privilege != UserPrivilege.SYSADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only system administrators can view user tools",
)
# Get tools for the specified user
tools_data = await chat_service.get_tools_for_user(
user_id=user_id, session=session
)
# Convert to ToolInfo objects
tools = [ToolInfo(**tool) for tool in tools_data]
logger.info(
f"User {currentUser.id} retrieved {len(tools)} tools for user {user_id}"
)
return ToolListResponse(tools=tools)
except HTTPException:
raise
except Exception as e:
logger.error(
f"Error retrieving tools for user {user_id}: {type(e).__name__}: {str(e)}",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve tools for user: {type(e).__name__}: {str(e) or 'No error message provided'}",
)
@router.get("/tools/me", response_model=ToolListResponse)
@limiter.limit("30/minute")
async def get_my_tools(
*,
request: Request,
currentUser: User = Depends(getCurrentUser),
session: AsyncSession = Depends(get_async_db_session),
) -> ToolListResponse:
"""
Get all tools the current user has access to.
"""
try:
# Get tools for the current user
tools_data = await chat_service.get_tools_for_user(
user_id=currentUser.id, session=session
)
# Convert to ToolInfo objects
tools = [ToolInfo(**tool) for tool in tools_data]
logger.info(
f"User {currentUser.id} retrieved {len(tools)} tools for themselves"
)
return ToolListResponse(tools=tools)
except Exception as e:
logger.error(
f"Error retrieving tools for user {currentUser.id}: {type(e).__name__}: {str(e)}",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve your tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
)

View file

@ -1,6 +1,5 @@
from typing import Dict, Any, List, Optional
import logging
from datetime import datetime, UTC
import uuid
import asyncio
import json