From f05d95821391787d630705b7408306b0a48dfde6 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Mon, 20 Oct 2025 22:01:03 +0200 Subject: [PATCH] removed chatbot on langgraph platform --- app.py | 2 - modules/datamodels/datamodelChatbot.py | 216 --- .../features/chatBot/chatbotTools/__init__.py | 1 - .../chatbotTools/customerTools/__init__.py | 1 - .../customerTools/toolQueryAlthausDatabase.py | 208 --- .../customerTools/toolValueOnPowerBi.py | 362 ----- .../chatbotTools/sharedTools/__init__.py | 7 - .../sharedTools/toolStreamingStatus.py | 24 - .../sharedTools/toolTavilySearch.py | 55 - modules/features/chatBot/domain/__init__.py | 1 - modules/features/chatBot/domain/chatbot.py | 301 ----- .../chatBot/domain/streaming_helper.py | 239 ---- modules/features/chatBot/mainChatBot.py | 1198 ----------------- .../features/chatBot/subChatbotDatabase.py | 197 --- .../features/chatBot/utils/checkpointer.py | 106 -- modules/features/chatBot/utils/permissions.py | 39 - .../features/chatBot/utils/toolRegistry.py | 305 ----- modules/features/featuresLifecycle.py | 13 +- modules/routes/routeChatbot.py | 653 --------- modules/workflows/workflowManager.py | 1 - 20 files changed, 7 insertions(+), 3922 deletions(-) delete mode 100644 modules/datamodels/datamodelChatbot.py delete mode 100644 modules/features/chatBot/chatbotTools/__init__.py delete mode 100644 modules/features/chatBot/chatbotTools/customerTools/__init__.py delete mode 100644 modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py delete mode 100644 modules/features/chatBot/chatbotTools/customerTools/toolValueOnPowerBi.py delete mode 100644 modules/features/chatBot/chatbotTools/sharedTools/__init__.py delete mode 100644 modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py delete mode 100644 modules/features/chatBot/chatbotTools/sharedTools/toolTavilySearch.py delete mode 100644 modules/features/chatBot/domain/__init__.py delete mode 100644 modules/features/chatBot/domain/chatbot.py delete mode 100644 modules/features/chatBot/domain/streaming_helper.py delete mode 100644 modules/features/chatBot/mainChatBot.py delete mode 100644 modules/features/chatBot/subChatbotDatabase.py delete mode 100644 modules/features/chatBot/utils/checkpointer.py delete mode 100644 modules/features/chatBot/utils/permissions.py delete mode 100644 modules/features/chatBot/utils/toolRegistry.py delete mode 100644 modules/routes/routeChatbot.py diff --git a/app.py b/app.py index c4b77472..2f7dbaf3 100644 --- a/app.py +++ b/app.py @@ -422,5 +422,3 @@ app.include_router(voiceGoogleRouter) from modules.routes.routeSecurityAdmin import router as adminSecurityRouter app.include_router(adminSecurityRouter) -# from modules.routes.routeChatbot import router as chatbotRouter -# app.include_router(chatbotRouter) diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py deleted file mode 100644 index 762996b3..00000000 --- a/modules/datamodels/datamodelChatbot.py +++ /dev/null @@ -1,216 +0,0 @@ -"""Chatbot API models for request/response handling.""" - -from typing import List, Optional -from pydantic import BaseModel, Field -from modules.shared.attributeUtils import register_model_labels, ModelMixin - - -# Chatbot API Models -class MessageItem(BaseModel, ModelMixin): - """Individual message in a thread""" - - role: str = Field(..., description="Message role (user or assistant)") - content: str = Field(..., description="Message content") - - -class ChatMessageRequest(BaseModel, ModelMixin): - """Request model for posting a chat message""" - - thread_id: Optional[str] = Field( - None, description="Thread ID (creates new thread if not provided)" - ) - message: str = Field(..., description="User message content") - tools: Optional[List[str]] = Field( - None, - description="List of tool IDs to use. If not provided, all user's tools will be used", - ) - - -class ChatMessageResponse(BaseModel, ModelMixin): - """Response model for posting a chat message""" - - thread_id: str = Field(..., description="Thread ID") - messages: List[MessageItem] = Field(..., description="All messages in thread") - - -class ThreadSummary(BaseModel, ModelMixin): - """Summary of a chat thread for list view""" - - thread_id: str = Field(..., description="Thread ID") - thread_name: str = Field(..., description="Thread name") - date_created: float = Field(..., description="Thread creation timestamp") - date_updated: float = Field(..., description="Thread last updated timestamp") - - -class ThreadListResponse(BaseModel, ModelMixin): - """Response model for listing all threads""" - - threads: List[ThreadSummary] = Field(..., description="List of thread summaries") - - -class ThreadDetail(BaseModel, ModelMixin): - """Detailed view of a single thread""" - - thread_id: str = Field(..., description="Thread ID") - date_created: float = Field(..., description="Thread creation timestamp") - date_updated: float = Field(..., description="Thread last updated timestamp") - messages: List[MessageItem] = Field( - ..., description="All messages in chronological order" - ) - - -class RenameThreadRequest(BaseModel, ModelMixin): - """Request model for renaming a thread""" - - new_name: str = Field(..., description="New name for the thread") - - -class DeleteResponse(BaseModel, ModelMixin): - """Response model for delete operations""" - - message: str = Field(..., description="Confirmation message") - thread_id: str = Field(..., description="Deleted thread ID") - - -# Tool Management Models -class ToolInfo(BaseModel, ModelMixin): - """Information about a chatbot tool""" - - id: str = Field(..., description="Tool UUID") - tool_id: str = Field( - ..., description="Tool identifier (e.g., 'shared.tavily_search')" - ) - name: str = Field(..., description="Tool function name") - label: str = Field(..., description="Display label for the tool") - category: str = Field(..., description="Tool category (shared or customer)") - description: str = Field(..., description="Tool description") - is_active: bool = Field(..., description="Whether the tool is active") - date_created: float = Field(..., description="Creation timestamp") - date_updated: float = Field(..., description="Last update timestamp") - - -class ToolListResponse(BaseModel, ModelMixin): - """Response model for listing all tools""" - - tools: List[ToolInfo] = Field(..., description="List of available tools") - - -class GrantToolRequest(BaseModel, ModelMixin): - """Request model for granting a tool to a user""" - - user_id: str = Field(..., description="User ID to grant the tool to") - tool_id: str = Field(..., description="Tool UUID from tools table") - - -class GrantToolResponse(BaseModel, ModelMixin): - """Response model after granting a tool""" - - message: str = Field(..., description="Confirmation message") - user_id: str = Field(..., description="User ID") - tool_id: str = Field(..., description="Tool UUID") - - -class RevokeToolRequest(BaseModel, ModelMixin): - """Request model for revoking a tool from a user""" - - user_id: str = Field(..., description="User ID to revoke the tool from") - tool_id: str = Field(..., description="Tool UUID from tools table") - - -class RevokeToolResponse(BaseModel, ModelMixin): - """Response model after revoking a tool""" - - message: str = Field(..., description="Confirmation message") - user_id: str = Field(..., description="User ID") - tool_id: str = Field(..., description="Tool UUID") - - -class UpdateToolRequest(BaseModel, ModelMixin): - """Request model for updating a tool's label and description""" - - label: Optional[str] = Field(None, description="New label for the tool") - description: Optional[str] = Field(None, description="New description for the tool") - - -class UpdateToolResponse(BaseModel, ModelMixin): - """Response model after updating a tool""" - - message: str = Field(..., description="Confirmation message") - tool_id: str = Field(..., description="Tool UUID") - updated_fields: List[str] = Field(..., description="List of updated field names") - - -# Register model labels for internationalization -register_model_labels( - "MessageItem", - {"en": "Message Item", "fr": "Élément de message"}, - { - "role": {"en": "Role", "fr": "Rôle"}, - "content": {"en": "Content", "fr": "Contenu"}, - }, -) - -register_model_labels( - "ChatMessageRequest", - {"en": "Chat Message Request", "fr": "Demande de message de chat"}, - { - "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, - "message": {"en": "Message", "fr": "Message"}, - }, -) - -register_model_labels( - "ChatMessageResponse", - {"en": "Chat Message Response", "fr": "Réponse du message de chat"}, - { - "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, - "messages": {"en": "Messages", "fr": "Messages"}, - }, -) - -register_model_labels( - "ThreadSummary", - {"en": "Thread Summary", "fr": "Résumé du fil"}, - { - "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, - "thread_name": {"en": "Thread Name", "fr": "Nom du fil"}, - "date_created": {"en": "Date Created", "fr": "Date de création"}, - "date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"}, - }, -) - -register_model_labels( - "ThreadListResponse", - {"en": "Thread List Response", "fr": "Réponse de liste de fils"}, - { - "threads": {"en": "Threads", "fr": "Fils"}, - }, -) - -register_model_labels( - "ThreadDetail", - {"en": "Thread Detail", "fr": "Détail du fil"}, - { - "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, - "date_created": {"en": "Date Created", "fr": "Date de création"}, - "date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"}, - "messages": {"en": "Messages", "fr": "Messages"}, - }, -) - -register_model_labels( - "RenameThreadRequest", - {"en": "Rename Thread Request", "fr": "Demande de renommage de fil"}, - { - "new_name": {"en": "New Name", "fr": "Nouveau nom"}, - }, -) - -register_model_labels( - "DeleteResponse", - {"en": "Delete Response", "fr": "Réponse de suppression"}, - { - "message": {"en": "Message", "fr": "Message"}, - "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, - }, -) diff --git a/modules/features/chatBot/chatbotTools/__init__.py b/modules/features/chatBot/chatbotTools/__init__.py deleted file mode 100644 index 2bd4359d..00000000 --- a/modules/features/chatBot/chatbotTools/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Contains all tools available for the chatbot to use.""" diff --git a/modules/features/chatBot/chatbotTools/customerTools/__init__.py b/modules/features/chatBot/chatbotTools/customerTools/__init__.py deleted file mode 100644 index 52043b31..00000000 --- a/modules/features/chatBot/chatbotTools/customerTools/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tools that are shared between multiple customers go here.""" diff --git a/modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py b/modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py deleted file mode 100644 index 72c15f15..00000000 --- a/modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Althaus Database Query Tool for LangGraph. - -This tool provides database query capabilities for the Althaus database -via an external REST API. Only SELECT queries are allowed. -""" - -import logging -import asyncio -import re -from typing import Annotated -from langchain_core.tools import tool - -logger = logging.getLogger(__name__) - - -async def _mock_api_call(*, sql_query: str) -> dict: - """Mock the external REST API call to Althaus database. - - Args: - sql_query: The SQL SELECT query to execute - - Returns: - A dictionary containing the query results with columns and rows - """ - # Simulate network delay - await asyncio.sleep(0.5) - - # Mock response data based on common query patterns - if "users" in sql_query.lower(): - return { - "columns": ["id", "username", "email", "created_at"], - "rows": [ - [1, "john_doe", "john@example.com", "2024-01-15"], - [2, "jane_smith", "jane@example.com", "2024-02-20"], - [3, "bob_wilson", "bob@example.com", "2024-03-10"], - ], - "row_count": 3, - } - elif "products" in sql_query.lower(): - return { - "columns": ["product_id", "name", "price", "stock"], - "rows": [ - [101, "Widget A", 29.99, 150], - [102, "Widget B", 39.99, 75], - [103, "Widget C", 19.99, 200], - ], - "row_count": 3, - } - elif "orders" in sql_query.lower(): - return { - "columns": ["order_id", "customer_id", "total", "status"], - "rows": [ - [5001, 1, 129.99, "completed"], - [5002, 2, 89.50, "pending"], - [5003, 1, 199.99, "shipped"], - ], - "row_count": 3, - } - else: - # Generic response for other queries - return { - "columns": ["id", "value", "description"], - "rows": [ - [1, "Sample 1", "First sample entry"], - [2, "Sample 2", "Second sample entry"], - ], - "row_count": 2, - } - - -def _validate_select_query(*, sql_query: str) -> tuple[bool, str]: - """Validate that the query is a SELECT statement only. - - Args: - sql_query: The SQL query to validate - - Returns: - A tuple of (is_valid, error_message) - """ - # Remove leading/trailing whitespace and convert to lowercase for checking - normalized_query = sql_query.strip().lower() - - # Check if query starts with SELECT - if not normalized_query.startswith("select"): - return False, "Query must be a SELECT statement" - - # Check for dangerous keywords that should not be in a SELECT query - dangerous_keywords = [ - "insert", - "update", - "delete", - "drop", - "create", - "alter", - "truncate", - "grant", - "revoke", - "exec", - "execute", - ] - - for keyword in dangerous_keywords: - # Use word boundary to match whole words only - if re.search(rf"\b{keyword}\b", normalized_query): - return False, f"Query contains forbidden keyword: {keyword.upper()}" - - return True, "" - - -def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str: - """Format query results into a readable string. - - Args: - columns: List of column names - rows: List of row data - row_count: Total number of rows - - Returns: - Formatted string representation of the results - """ - if row_count == 0: - return "Query executed successfully but returned no results." - - # Calculate column widths - col_widths = [len(str(col)) for col in columns] - for row in rows: - for i, cell in enumerate(row): - col_widths[i] = max(col_widths[i], len(str(cell))) - - # Build header - header_parts = [] - for col, width in zip(columns, col_widths): - header_parts.append(str(col).ljust(width)) - header = " | ".join(header_parts) - separator = "-" * len(header) - - # Build rows - row_lines = [] - for row in rows: - row_parts = [] - for cell, width in zip(row, col_widths): - row_parts.append(str(cell).ljust(width)) - row_lines.append(" | ".join(row_parts)) - - # Combine all parts - result_parts = [ - f"Query returned {row_count} row(s):\n", - header, - separator, - "\n".join(row_lines), - ] - - return "\n".join(result_parts) - - -@tool -async def query_althaus_database( - sql_query: Annotated[ - str, "The SQL SELECT query to execute against the Althaus database" - ], -) -> str: - """Execute a SELECT query against the Althaus database via REST API. - - Use this tool to query data from the Althaus database. Only SELECT statements - are allowed for security reasons. The query will be forwarded to an external - REST API and the results will be returned in a formatted table. - - Args: - sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1") - - Returns: - A formatted string containing the query results with columns and rows - """ - try: - # Validate the query - is_valid, error_msg = _validate_select_query(sql_query=sql_query) - if not is_valid: - logger.warning(f"Invalid query attempt: {sql_query[:100]}...") - return f"Error: {error_msg}" - - logger.info(f"Executing Althaus database query: {sql_query[:100]}...") - - # Mock the external REST API call - # In production, this would be replaced with actual REST API call: - # response = await httpx.AsyncClient().post( - # "https://api.althaus.example.com/query", - # json={"query": sql_query}, - # headers={"Authorization": f"Bearer {api_key}"} - # ) - # result = response.json() - - result = await _mock_api_call(sql_query=sql_query) - - # Format and return results - formatted_output = _format_results( - columns=result["columns"], - rows=result["rows"], - row_count=result["row_count"], - ) - - logger.info( - f"Query completed successfully, returned {result['row_count']} row(s)" - ) - return formatted_output - - except Exception as e: - logger.error(f"Error in query_althaus_database tool: {str(e)}") - return f"Error executing query: {str(e)}" diff --git a/modules/features/chatBot/chatbotTools/customerTools/toolValueOnPowerBi.py b/modules/features/chatBot/chatbotTools/customerTools/toolValueOnPowerBi.py deleted file mode 100644 index 7c00ee6a..00000000 --- a/modules/features/chatBot/chatbotTools/customerTools/toolValueOnPowerBi.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Power BI Query Tool for LangGraph. - -This tool provides DAX query capabilities for Power BI datasets -via the Power BI REST API. Only read-only queries are allowed. -""" - -import logging -import asyncio -import os -import re -import functools -from typing import Annotated - -import anyio -import httpx -from langchain_core.tools import tool -from msal import ConfidentialClientApplication, SerializableTokenCache - -from modules.shared.configuration import APP_CONFIG - -logger = logging.getLogger(__name__) - - -# Configuration constants - encapsulated in this file -POWERBI_DATASET_ID = APP_CONFIG.get("VALUEON_POWERBI_DATASET_ID", "") -POWERBI_CLIENT_ID = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_ID", "") -POWERBI_CLIENT_SECRET = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_SECRET", "") -POWERBI_TENANT_ID = APP_CONFIG.get("VALUEON_POWERBI_TENANT_ID", "") -POWERBI_BASE_URL = "https://api.powerbi.com/v1.0/myorg" -POWERBI_AUTHORITY_BASE = "https://login.microsoftonline.com" -POWERBI_SCOPE = ["https://analysis.windows.net/powerbi/api/.default"] - -# Limit results to prevent excessive context usage -MAX_ROWS_LIMIT = 100 - - -def _validate_environment() -> tuple[bool, str]: - """Validate that all required environment variables are set. - - Returns: - A tuple of (is_valid, error_message) - """ - missing = [] - if not POWERBI_DATASET_ID: - missing.append("POWERBI_DATASET_ID") - if not POWERBI_CLIENT_ID: - missing.append("POWERBI_CLIENT_ID") - if not POWERBI_CLIENT_SECRET: - missing.append("POWERBI_CLIENT_SECRET") - if not POWERBI_TENANT_ID: - missing.append("POWERBI_TENANT_ID") - - if missing: - return False, f"Missing required environment variables: {', '.join(missing)}" - - return True, "" - - -def _validate_dax_query(*, dax_query: str) -> tuple[bool, str]: - """Validate that the query is a valid DAX query. - - Args: - dax_query: The DAX query to validate - - Returns: - A tuple of (is_valid, error_message) - """ - # Remove leading/trailing whitespace - normalized_query = dax_query.strip() - - if not normalized_query: - return False, "Query cannot be empty" - - # DAX queries typically start with EVALUATE, DEFINE, or are table expressions - # We'll be lenient and just check it's not trying to do something dangerous - # DAX is read-only by nature, but we validate structure - - # Check for minimum length - if len(normalized_query) < 5: - return False, "Query is too short to be valid" - - return True, "" - - -def _get_access_token_sync( - *, - tenant_id: str, - client_id: str, - client_secret: str, - authority_base: str = POWERBI_AUTHORITY_BASE, - cache: SerializableTokenCache | None = None, -) -> str: - """Get Power BI access token using MSAL (synchronous). - - Args: - tenant_id: Azure AD tenant ID - client_id: Application client ID - client_secret: Application client secret - authority_base: Azure AD authority base URL - cache: Optional token cache for reuse - - Returns: - Access token string - - Raises: - RuntimeError: If token acquisition fails - """ - authority = f"{authority_base}/{tenant_id}" - - app = ConfidentialClientApplication( - client_id=client_id, - authority=authority, - client_credential=client_secret, - token_cache=cache, - ) - - # Try cache first; fall back to client credentials - result = app.acquire_token_silent( - POWERBI_SCOPE, account=None - ) or app.acquire_token_for_client(scopes=POWERBI_SCOPE) - - if "access_token" not in result: - raise RuntimeError( - f"MSAL token error: {result.get('error')} - {result.get('error_description')}" - ) - - return result["access_token"] - - -async def _get_access_token_async( - *, - tenant_id: str, - client_id: str, - client_secret: str, - **kwargs, -) -> str: - """Get Power BI access token using MSAL (asynchronous). - - Args: - tenant_id: Azure AD tenant ID - client_id: Application client ID - client_secret: Application client secret - **kwargs: Additional arguments for _get_access_token_sync - - Returns: - Access token string - """ - # Create a partial function with arguments pre-filled - func = functools.partial( - _get_access_token_sync, - tenant_id=tenant_id, - client_id=client_id, - client_secret=client_secret, - **kwargs, - ) - # Offload the blocking MSAL HTTP call to a worker thread - return await anyio.to_thread.run_sync(func) - - -async def _execute_dax_query( - *, dax_query: str, dataset_id: str, access_token: str -) -> dict: - """Execute a DAX query against Power BI dataset. - - Args: - dax_query: The DAX query to execute - dataset_id: Power BI dataset ID - access_token: Access token for authentication - - Returns: - Dictionary containing query results - - Raises: - RuntimeError: If query execution fails - """ - url = f"{POWERBI_BASE_URL}/datasets/{dataset_id}/executeQueries" - - body = { - "queries": [{"query": dax_query}], - "serializerSettings": {"includeNulls": True}, - } - - headers = { - "Authorization": f"Bearer {access_token}", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post(url, headers=headers, json=body) - - if resp.status_code != 200: - raise RuntimeError( - f"Power BI executeQueries failed: {resp.status_code} - {resp.text}" - ) - - payload = resp.json() - - try: - rows = payload["results"][0]["tables"][0]["rows"] - except (KeyError, IndexError) as e: - raise RuntimeError("Unexpected executeQueries response structure") from e - - # Extract column names from the first row if available - if rows: - columns = list(rows[0].keys()) - else: - columns = [] - - return {"columns": columns, "rows": rows} - - -def _strip_table_qualifier(*, column_name: str) -> str: - """Strip table qualifier from column name. - - Power BI often returns columns as 'Table[Column]'. This strips to 'Column'. - - Args: - column_name: The column name to process - - Returns: - Processed column name - """ - if "[" in column_name and column_name.endswith("]"): - return column_name.split("[", 1)[1][:-1] - return column_name - - -def _format_results(*, columns: list[str], rows: list[dict], max_rows: int) -> str: - """Format query results into a readable string. - - Args: - columns: List of column names - rows: List of row data (as dictionaries) - max_rows: Maximum number of rows to display - - Returns: - Formatted string representation of the results - """ - total_rows = len(rows) - - if total_rows == 0: - return "Query executed successfully but returned no results." - - # Strip table qualifiers from column names - clean_columns = [_strip_table_qualifier(column_name=col) for col in columns] - - # Limit rows to max_rows - display_rows = rows[:max_rows] - truncated = total_rows > max_rows - - # Calculate column widths - col_widths = [len(str(col)) for col in clean_columns] - for row in display_rows: - for i, col in enumerate(columns): - value = row.get(col, "") - col_widths[i] = max(col_widths[i], len(str(value))) - - # Build header - header_parts = [] - for col, width in zip(clean_columns, col_widths): - header_parts.append(str(col).ljust(width)) - header = " | ".join(header_parts) - separator = "-" * len(header) - - # Build rows - row_lines = [] - for row in display_rows: - row_parts = [] - for col, width in zip(columns, col_widths): - value = row.get(col, "") - row_parts.append(str(value).ljust(width)) - row_lines.append(" | ".join(row_parts)) - - # Combine all parts - result_parts = [ - f"Query returned {total_rows} row(s):", - ] - - if truncated: - result_parts.append( - f"(Results limited to {max_rows} rows for context efficiency)\n" - ) - else: - result_parts.append("") - - result_parts.extend([header, separator, "\n".join(row_lines)]) - - return "\n".join(result_parts) - - -@tool -async def query_powerbi_data( - dax_query: Annotated[str, "The DAX query to execute against the Power BI dataset"], -) -> str: - """Execute a DAX query against the Power BI dataset to access warehouse inventory data. - - This tool provides access to a Power BI table called 'data_full' which contains - articles available in the warehouse of the user. Use DAX (Data Analysis Expressions) - queries to retrieve and analyze this inventory data. - - Available table: - - 'data_full': Contains warehouse inventory articles and their details - - Common query patterns: - - View all data: EVALUATE 'data_full' - - With filter: EVALUATE FILTER('data_full', [Column] = "Value") - - Top N rows: EVALUATE TOPN(10, 'data_full', [Column], DESC) - - Calculated: EVALUATE SUMMARIZE('data_full', [Column1], "Total", SUM([Column2])) - - Results are limited to 100 rows maximum for efficiency. - - Args: - dax_query: The DAX query to execute (e.g., "EVALUATE 'data_full'") - - Returns: - A formatted string containing the query results with columns and rows - """ - try: - # Validate environment configuration - is_valid_env, error_msg = _validate_environment() - if not is_valid_env: - logger.error(f"Environment validation failed: {error_msg}") - return f"Configuration Error: {error_msg}" - - # Validate the query - is_valid_query, error_msg = _validate_dax_query(dax_query=dax_query) - if not is_valid_query: - logger.warning(f"Invalid query attempt: {dax_query[:100]}...") - return f"Query Validation Error: {error_msg}" - - logger.info(f"Executing Power BI query: {dax_query[:100]}...") - - # Get access token - access_token = await _get_access_token_async( - tenant_id=POWERBI_TENANT_ID, - client_id=POWERBI_CLIENT_ID, - client_secret=POWERBI_CLIENT_SECRET, - ) - - # Execute the query - result = await _execute_dax_query( - dax_query=dax_query, - dataset_id=POWERBI_DATASET_ID, - access_token=access_token, - ) - - # Format and return results - formatted_output = _format_results( - columns=result["columns"], rows=result["rows"], max_rows=MAX_ROWS_LIMIT - ) - - logger.info( - f"Query completed successfully, returned {len(result['rows'])} row(s)" - ) - return formatted_output - - except RuntimeError as e: - logger.error(f"Runtime error in query_powerbi_data tool: {str(e)}") - return f"Error executing query: {str(e)}" - except Exception as e: - logger.error(f"Unexpected error in query_powerbi_data tool: {str(e)}") - return f"Unexpected error: {str(e)}" diff --git a/modules/features/chatBot/chatbotTools/sharedTools/__init__.py b/modules/features/chatBot/chatbotTools/sharedTools/__init__.py deleted file mode 100644 index 9b0ab5b7..00000000 --- a/modules/features/chatBot/chatbotTools/sharedTools/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Shared tools available across all chatbot implementations.""" - -from modules.features.chatBot.chatbotTools.sharedTools.toolTavilySearch import ( - tavily_search, -) - -__all__ = ["tavily_search"] diff --git a/modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py b/modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py deleted file mode 100644 index f2587be8..00000000 --- a/modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Tool for sending streaming status updates to users.""" - -from langchain_core.tools import tool - - -@tool -def send_streaming_message(message: str) -> str: - """Send a streaming message to the user to provide updates during processing. - - Use this tool to send short status updates to the user while you are working - on their request. This helps keep the user informed about what you are doing. - - Args: - message: A short message describing what you are currently doing. - Examples: "Searching database for relevant information..." - "Analyzing search results..." - "Processing your request..." - - Returns: - A confirmation that the message was sent. - """ - # This tool doesn't actually do anything - it's just for the AI to signal - # what it's doing to the frontend via the tool call mechanism - return f"Status update sent: {message}" diff --git a/modules/features/chatBot/chatbotTools/sharedTools/toolTavilySearch.py b/modules/features/chatBot/chatbotTools/sharedTools/toolTavilySearch.py deleted file mode 100644 index e3a6c8fc..00000000 --- a/modules/features/chatBot/chatbotTools/sharedTools/toolTavilySearch.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Tavily Search Tool for LangGraph. - -This tool provides web search capabilities using the Tavily API. -""" - -import logging -from typing import Annotated -from langchain_core.tools import tool -from modules.connectors.connectorAiTavily import ConnectorWeb - -logger = logging.getLogger(__name__) - - -@tool -async def tavily_search( - query: Annotated[str, "The search query to look up on the web"], -) -> str: - """Search the web using Tavily API. - - Use this tool to search for current information, news, or any web content. - The tool returns relevant search results including titles and URLs. - - Args: - query: The search query string - - Returns: - A formatted string containing search results with titles and URLs - """ - try: - # Create connector instance - connector = await ConnectorWeb.create() - - # Perform search with default parameters - results = await connector._search( - query=query, - max_results=5, - search_depth="basic", - include_answer=True, - include_raw_content=False, - ) - - # Format results - if not results: - return f"No results found for query: {query}" - - formatted_results = [f"Search results for '{query}':\n"] - for i, result in enumerate(results, 1): - formatted_results.append(f"{i}. {result.title}") - formatted_results.append(f" URL: {result.url}\n") - - return "\n".join(formatted_results) - - except Exception as e: - logger.error(f"Error in tavily_search tool: {str(e)}") - return f"Error performing search: {str(e)}" diff --git a/modules/features/chatBot/domain/__init__.py b/modules/features/chatBot/domain/__init__.py deleted file mode 100644 index abd60dca..00000000 --- a/modules/features/chatBot/domain/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Domain logic for chatbot functionality.""" diff --git a/modules/features/chatBot/domain/chatbot.py b/modules/features/chatBot/domain/chatbot.py deleted file mode 100644 index 548ea966..00000000 --- a/modules/features/chatBot/domain/chatbot.py +++ /dev/null @@ -1,301 +0,0 @@ -"""Chatbot domain logic with LangGraph integration.""" - -from dataclasses import dataclass -from typing import Annotated, AsyncIterator, Any -import logging - -from pydantic import BaseModel -from langchain_core.messages import ( - BaseMessage, - HumanMessage, - SystemMessage, - trim_messages, -) -from langgraph.graph.message import add_messages -from langgraph.graph import StateGraph, START, END -from langgraph.graph.state import CompiledStateGraph -from langgraph.prebuilt import ToolNode -from langchain_anthropic import ChatAnthropic - -from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper -from modules.features.chatBot.utils.toolRegistry import get_registry -from modules.shared.configuration import APP_CONFIG - -logger = logging.getLogger(__name__) - - -class ChatState(BaseModel): - """Represents the state of a chat session.""" - - messages: Annotated[list[BaseMessage], add_messages] - - -def get_langchain_model(*, model_name: str) -> ChatAnthropic: - """Map permission model names to LangChain ChatAnthropic models. - - Args: - model_name: The model name from permissions (e.g., "claude_4_5") - - Returns: - Configured ChatAnthropic instance - - Raises: - ValueError: If the model name is not supported - """ - # Model name mapping - model_mapping = { - "claude_4_5": "claude-sonnet-4-5", - # Add more mappings as needed - } - - anthropic_model = model_mapping.get(model_name) - if not anthropic_model: - logger.warning( - f"Unknown model name '{model_name}', defaulting to claude-4-5-sonnet" - ) - anthropic_model = "claude-4-5-sonnet" - - return ChatAnthropic( - model=anthropic_model, - api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"), - temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)), - max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)), - ) - - -@dataclass -class Chatbot: - """Represents a chatbot with LangGraph integration.""" - - model: Any - memory: Any - app: Any = None - system_prompt: str = "You are a helpful assistant." - context_window_size: int = 100000 - - @classmethod - async def create( - cls, - *, - model: Any, - memory: Any, - system_prompt: str, - tools: list, - context_window_size: int = 100000, - ) -> "Chatbot": - """Factory method to create and configure a Chatbot instance. - - Args: - model: The chat model to use. - memory: The chat memory checkpointer to use. - system_prompt: The system prompt to initialize the chatbot. - tools: List of LangChain tools the chatbot can use. - context_window_size: Maximum tokens for context window. - - Returns: - A configured Chatbot instance. - """ - instance = cls( - model=model, - memory=memory, - system_prompt=system_prompt, - context_window_size=context_window_size, - ) - instance.app = instance._build_app(memory=memory, tools=tools) - return instance - - def _build_app( - self, *, memory: Any, tools: list - ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]: - """Builds the chatbot application workflow using LangGraph. - - Args: - memory: The chat memory checkpointer to use. - tools: The list of tools the chatbot can use. - - Returns: - A compiled state graph representing the chatbot application. - """ - llm_with_tools = self.model.bind_tools(tools=tools) - - def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]: - """Selects a window of messages that fit within the context window size. - - Args: - msgs: The list of messages to select from. - - Returns: - A list of messages that fit within the context window size. - """ - - def approx_counter(items: list[BaseMessage]) -> int: - """Approximate token counter for messages. - - Args: - items: List of messages to count tokens for. - - Returns: - Approximate number of tokens in the messages. - """ - return sum(len(getattr(m, "content", "") or "") for m in items) - - return trim_messages( - msgs, - strategy="last", - token_counter=approx_counter, - max_tokens=self.context_window_size, - start_on="human", - end_on=("human", "tool"), - include_system=True, - ) - - def agent_node(state: ChatState) -> dict: - """Agent node for the chatbot workflow. - - Args: - state: The current chat state. - - Returns: - The updated chat state after processing. - """ - # Select the message window to fit in context (trim if needed) - window = select_window(state.messages) - - # Ensure the system prompt is present at the start - if not window or not isinstance(window[0], SystemMessage): - window = [SystemMessage(content=self.system_prompt)] + window - - # Call the LLM with tools - response = llm_with_tools.invoke(window) - - # Return the new state - return {"messages": [response]} - - def should_continue(state: ChatState) -> str: - """Determines whether to continue the workflow or end it. - - This conditional edge is called after the agent node to decide - whether to continue to the tools node (if the last message contains - tool calls) or to end the workflow (if no tool calls are present). - - Args: - state: The current chat state. - - Returns: - The next node to transition to ("tools" or END). - """ - # Get the last message - last_message = state.messages[-1] - - # Check if the last message contains tool calls - # If so, continue to the tools node; otherwise, end the workflow - return "tools" if getattr(last_message, "tool_calls", None) else END - - # Compose the workflow - workflow = StateGraph(ChatState) - workflow.add_node("agent", agent_node) - workflow.add_node("tools", ToolNode(tools=tools)) - workflow.add_edge(START, "agent") - workflow.add_conditional_edges("agent", should_continue) - workflow.add_edge("tools", "agent") - return workflow.compile(checkpointer=memory) - - async def chat( - self, *, message: str, chat_id: str = "default" - ) -> list[BaseMessage]: - """Processes a chat message and returns the chat history. - - Args: - message: The user message to process. - chat_id: The chat thread ID. - - Returns: - The list of messages in the chat history. - """ - # Set the right thread ID for memory - config = {"configurable": {"thread_id": chat_id}} - - # Single-turn chat (non-streaming) - result = await self.app.ainvoke( - {"messages": [HumanMessage(content=message)]}, config=config - ) - - # Extract and return the messages from the result - return result["messages"] - - async def stream_events( - self, *, message: str, chat_id: str = "default" - ) -> AsyncIterator[dict]: - """Stream UI-focused events using astream_events v2. - - Args: - message: The user message to process. - chat_id: Logical thread identifier; forwarded in the runnable config so - memory and tools are scoped per thread. - - Yields: - dict: One of: - - ``{"type": "status", "label": str}`` for short progress updates. - - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}`` - where ``chat_history`` only includes ``user``/``assistant`` roles. - - ``{"type": "error", "message": str}`` if an exception occurs. - """ - # Thread-aware config for LangGraph/LangChain - config = {"configurable": {"thread_id": chat_id}} - - def _is_root(ev: dict) -> bool: - """Return True if the event is from the root run (v2: empty parent_ids).""" - return not ev.get("parent_ids") - - try: - async for event in self.app.astream_events( - {"messages": [HumanMessage(content=message)]}, - config=config, - version="v2", - ): - etype = event.get("event") - ename = event.get("name") or "" - edata = event.get("data") or {} - - # Stream human-readable progress via the special send_streaming_message tool - if etype == "on_tool_start" and ename == "send_streaming_message": - tool_in = edata.get("input") or {} - msg = tool_in.get("message") - if isinstance(msg, str) and msg.strip(): - yield {"type": "status", "label": msg.strip()} - continue - - # Emit the final payload when the root run finishes - if etype == "on_chain_end" and _is_root(event): - output_obj = edata.get("output") - - # Extract message list from the graph's final output - final_msgs = ChatStreamingHelper.extract_messages_from_output( - output_obj=output_obj - ) - - # Normalize for the frontend (only user/assistant with text content) - chat_history_payload: list[dict] = [] - for m in final_msgs: - if isinstance(m, BaseMessage): - d = ChatStreamingHelper.message_to_dict(msg=m) - elif isinstance(m, dict): - d = ChatStreamingHelper.dict_message_to_dict(obj=m) - else: - continue - if d.get("role") in ("user", "assistant") and d.get("content"): - chat_history_payload.append(d) - - yield { - "type": "final", - "response": { - "thread": chat_id, - "chat_history": chat_history_payload, - }, - } - return - - except Exception as exc: - # Emit a single error envelope and end the stream - logger.error(f"Error in stream_events: {str(exc)}", exc_info=True) - yield {"type": "error", "message": f"Error processing request: {exc}"} diff --git a/modules/features/chatBot/domain/streaming_helper.py b/modules/features/chatBot/domain/streaming_helper.py deleted file mode 100644 index f8c73b45..00000000 --- a/modules/features/chatBot/domain/streaming_helper.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Streaming helper utilities for chat message processing and normalization.""" - -from typing import Any, Dict, List, Literal, Mapping, Optional - -from langchain_core.messages import ( - AIMessage, - BaseMessage, - HumanMessage, - SystemMessage, - ToolMessage, -) - -Role = Literal["user", "assistant", "system", "tool"] - - -class ChatStreamingHelper: - """Pure helper methods for streaming and message normalization. - - This class provides static utility methods for converting between different - message formats, extracting content, and normalizing message structures - for streaming chat applications. - """ - - @staticmethod - def role_from_message(*, msg: BaseMessage) -> Role: - """Extract the role from a BaseMessage instance. - - Args: - msg: The BaseMessage instance to extract the role from. - - Returns: - The role as a string literal: "user", "assistant", "system", or "tool". - Defaults to "assistant" if the message type is not recognized. - - Examples: - >>> from langchain_core.messages import HumanMessage - >>> msg = HumanMessage(content="Hello") - >>> ChatStreamingHelper.role_from_message(msg=msg) - 'user' - """ - if isinstance(msg, HumanMessage): - return "user" - if isinstance(msg, AIMessage): - return "assistant" - if isinstance(msg, SystemMessage): - return "system" - if isinstance(msg, ToolMessage): - return "tool" - return getattr(msg, "role", "assistant") - - @staticmethod - def flatten_content(*, content: Any) -> str: - """Convert complex content structures to plain text. - - This method handles various content formats including strings, lists of - content parts, and dictionaries with text fields. It's designed to - normalize content from different message sources into a consistent - plain text format. - - Args: - content: The content to flatten. Can be: - - str: Returned as-is after stripping whitespace - - list: Each item processed and joined with newlines - - dict: Text extracted from "text" or "content" fields - - None: Returns empty string - - Any other type: Converted to string - - Returns: - The flattened content as a plain text string with whitespace stripped. - - Examples: - >>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}] - >>> ChatStreamingHelper.flatten_content(content=content) - 'Hello - nworld' - - >>> content = {"text": "Simple message"} - >>> ChatStreamingHelper.flatten_content(content=content) - 'Simple message' - """ - if content is None: - return "" - if isinstance(content, str): - return content.strip() - if isinstance(content, list): - parts: List[str] = [] - for part in content: - if isinstance(part, dict): - if "text" in part and isinstance(part["text"], str): - parts.append(part["text"]) - elif part.get("type") == "text" and isinstance( - part.get("text"), str - ): - parts.append(part["text"]) - elif "content" in part and isinstance(part["content"], str): - parts.append(part["content"]) - else: - # Fallback for unknown dictionary structures - val = part.get("value") - if isinstance(val, str): - parts.append(val) - else: - parts.append(str(part)) - return "\n".join(p.strip() for p in parts if p is not None) - if isinstance(content, dict): - if "text" in content and isinstance(content["text"], str): - return content["text"].strip() - if "content" in content and isinstance(content["content"], str): - return content["content"].strip() - return str(content).strip() - - @staticmethod - def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]: - """Convert a BaseMessage instance to a dictionary for streaming output. - - This method normalizes BaseMessage instances into a consistent dictionary - format suitable for JSON serialization and streaming to clients. - - Args: - msg: The BaseMessage instance to convert. - - Returns: - A dictionary containing: - - "role": The message role (user, assistant, system, tool) - - "content": The flattened message content as plain text - - "tool_calls": Tool calls if present (optional) - - "name": Message name if present (optional) - - Examples: - >>> from langchain_core.messages import HumanMessage - >>> msg = HumanMessage(content="Hello there") - >>> result = ChatStreamingHelper.message_to_dict(msg=msg) - >>> result["role"] - 'user' - >>> result["content"] - 'Hello there' - """ - payload: Dict[str, Any] = { - "role": ChatStreamingHelper.role_from_message(msg=msg), - "content": ChatStreamingHelper.flatten_content( - content=getattr(msg, "content", "") - ), - } - tool_calls = getattr(msg, "tool_calls", None) - if tool_calls: - payload["tool_calls"] = tool_calls - name = getattr(msg, "name", None) - if name: - payload["name"] = name - return payload - - @staticmethod - def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]: - """Convert a dictionary-shaped message to a normalized dictionary. - - This method handles messages that come from serialized state and are - represented as dictionaries rather than BaseMessage instances. It - normalizes various dictionary formats into a consistent structure. - - Args: - obj: The dictionary-shaped message to convert. Expected to contain - fields like "role", "type", "content", "text", etc. - - Returns: - A normalized dictionary containing: - - "role": The message role (user, assistant, system, tool) - - "content": The flattened message content as plain text - - "tool_calls": Tool calls if present (optional) - - "name": Message name if present (optional) - - Examples: - >>> obj = {"type": "human", "content": "Hello"} - >>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj) - >>> result["role"] - 'user' - >>> result["content"] - 'Hello' - """ - role: Optional[str] = obj.get("role") - if not role: - # Handle alternative type field mappings - typ = obj.get("type") - if typ in ("human", "user"): - role = "user" - elif typ in ("ai", "assistant"): - role = "assistant" - elif typ in ("system",): - role = "system" - elif typ in ("tool", "function"): - role = "tool" - - content = obj.get("content") - if content is None and "text" in obj: - content = obj["text"] - - out: Dict[str, Any] = { - "role": role or "assistant", - "content": ChatStreamingHelper.flatten_content(content=content), - } - if "tool_calls" in obj: - out["tool_calls"] = obj["tool_calls"] - if obj.get("name"): - out["name"] = obj["name"] - return out - - @staticmethod - def extract_messages_from_output(*, output_obj: Any) -> List[Any]: - """Extract messages from LangGraph output objects. - - This method handles various output formats from LangGraph execution, - extracting the messages list from different possible structures. - - Args: - output_obj: The output object from LangGraph execution. Can be: - - An object with a "messages" attribute - - A dictionary with a "messages" key - - Any other object (returns empty list) - - Returns: - A list of extracted messages, or an empty list if no messages - are found or if the output object is None. - - Examples: - >>> output = {"messages": [{"role": "user", "content": "Hello"}]} - >>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output) - >>> len(messages) - 1 - """ - if output_obj is None: - return [] - - # Try to parse dicts first - if isinstance(output_obj, dict): - msgs = output_obj.get("messages") - return msgs if isinstance(msgs, list) else [] - - # Then try to get messages attribute - msgs = getattr(output_obj, "messages", None) - return msgs if isinstance(msgs, list) else [] diff --git a/modules/features/chatBot/mainChatBot.py b/modules/features/chatBot/mainChatBot.py deleted file mode 100644 index 2d44e422..00000000 --- a/modules/features/chatBot/mainChatBot.py +++ /dev/null @@ -1,1198 +0,0 @@ -"""Service layer for chatbot functionality.""" - -import json -import asyncio -import logging -from datetime import datetime, timezone -import sys -from typing import AsyncIterator, List, Optional - -from sqlalchemy import select, update, delete -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker -from sqlalchemy.exc import OperationalError - -from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model -from modules.features.chatBot.utils.checkpointer import get_checkpointer -from modules.features.chatBot.utils.toolRegistry import get_registry -from modules.features.chatBot.utils import permissions -from modules.features.chatBot.subChatbotDatabase import UserThreadMapping -from modules.datamodels.datamodelChatbot import ( - MessageItem, - ChatMessageResponse, - ThreadSummary, - ThreadDetail, -) -from modules.datamodels.datamodelUam import User - -from langchain_core.messages import HumanMessage, AIMessage, BaseMessage -from langgraph.graph import StateGraph, MessagesState, START, END -from modules.shared.configuration import APP_CONFIG - -logger = logging.getLogger(__name__) - - -_closeCheckpointerCallable = None # set when start() initializes checkpointer -_engine = None -_SessionLocal = None - - -def _make_sqlalchemy_db_url() -> str: - from urllib.parse import quote_plus - - host = APP_CONFIG.get("SQLALCHEMY_DB_HOST", "localhost") - port = APP_CONFIG.get("SQLALCHEMY_DB_PORT", "5432") - db = APP_CONFIG.get("SQLALCHEMY_DB_DATABASE", "project_gateway") - user = APP_CONFIG.get("SQLALCHEMY_DB_USER", "postgres") - pwd = quote_plus(APP_CONFIG.get("SQLALCHEMY_DB_PASSWORD_SECRET", "")) - if sys.platform == "win32": - return f"postgresql+asyncpg://{user}:{pwd}@{host}:{port}/{db}" - return f"postgresql+psycopg://{user}:{pwd}@{host}:{port}/{db}" - - -def _create_engine_with_pool() -> tuple: - """Create async SQLAlchemy engine and sessionmaker with resilient pool settings.""" - db_url = _make_sqlalchemy_db_url() - - # Pool tuning with sensible defaults; overridable via config - pool_size = int(APP_CONFIG.get("SQLALCHEMY_POOL_SIZE", 5)) - max_overflow = int(APP_CONFIG.get("SQLALCHEMY_MAX_OVERFLOW", 10)) - pool_recycle = int(APP_CONFIG.get("SQLALCHEMY_POOL_RECYCLE_SECONDS", 300)) - pool_timeout = int(APP_CONFIG.get("SQLALCHEMY_POOL_TIMEOUT_SECONDS", 30)) - connect_timeout = int(APP_CONFIG.get("SQLALCHEMY_CONNECT_TIMEOUT_SECONDS", 10)) - - engine = create_async_engine( - db_url, - pool_pre_ping=True, - pool_size=pool_size, - max_overflow=max_overflow, - pool_recycle=pool_recycle, - pool_timeout=pool_timeout, - echo=False, - connect_args={ - # asyncpg understands timeout; psycopg ignores unknown args safely - "timeout": connect_timeout, - }, - ) - session_local = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) - return engine, session_local - - -async def start() -> None: - """Initialize ChatBot feature at application startup. - - - Creates tables if needed - - Syncs tool registry to database - - Initializes LangGraph checkpointer (except in dev) - """ - global _engine, _SessionLocal - - from modules.features.chatBot.subChatbotDatabase import init_models as _initModels - from modules.features.chatBot.subChatbotDatabase import ( - sync_tools_from_registry as _syncToolsFromRegistry, - ) - - # Ensure Windows uses SelectorEventLoop for better DB driver compatibility - if sys.platform == "win32": - return - - try: - if _engine is None: - _engine, _SessionLocal = _create_engine_with_pool() - - # Ensure DB schema exists with retry (handles transient startup issues) - await _initModelsWithRetry(_engine, _initModels) - - # Sync tools into DB - async with _SessionLocal() as session: - await _syncToolsFromRegistry(session) - await session.commit() - logger.info("ChatBot tools synced from registry to database") - except Exception as exc: - logger.error( - f"ChatBot startup failed: {type(exc).__name__}: {str(exc)}", - exc_info=True, - ) - # Intentionally swallow to avoid aborting app startup - return - - # Initialize LangGraph checkpointer (skip in dev) - global _closeCheckpointerCallable - isDev = str(APP_CONFIG.get("APP_ENV_LABEL")).lower() in ("dev", "development") - if not isDev: - try: - from modules.features.chatBot.utils.checkpointer import ( - initialize_checkpointer as _initializeCheckpointer, - close_checkpointer as _closeCheckpointer, - ) - - await _initializeCheckpointer() - _closeCheckpointerCallable = _closeCheckpointer - logger.info("LangGraph checkpointer initialized successfully (ChatBot)") - except Exception as e: - logger.error( - f"Failed to initialize LangGraph checkpointer (ChatBot): {str(e)}" - ) - _closeCheckpointerCallable = None - else: - _closeCheckpointerCallable = None - logger.info("LangGraph checkpointer disabled in dev environment (ChatBot)") - - -async def stop() -> None: - """Shutdown hook for ChatBot feature (closes checkpointer if initialized).""" - global _closeCheckpointerCallable - try: - if callable(_closeCheckpointerCallable): - try: - await _closeCheckpointerCallable() - finally: - _closeCheckpointerCallable = None - # Dispose engine if created - global _engine - if _engine is not None: - try: - await _engine.dispose() - finally: - _engine = None - except Exception as exc: - logger.warning( - f"ChatBot shutdown encountered an error: {type(exc).__name__}: {str(exc)}", - exc_info=True, - ) - - -async def _initModelsWithRetry(engine, initModelsCallable, *, maxRetries: int = 5, baseDelaySeconds: float = 0.5) -> None: - """Initialize DB models with exponential backoff to avoid failing app startup on transient DB issues.""" - attempt = 0 - while True: - try: - await initModelsCallable(engine) - return - except Exception as exc: - attempt += 1 - if attempt > maxRetries: - logger.error( - f"Failed to initialize chatbot DB models after {maxRetries} attempts: {type(exc).__name__}: {str(exc)}", - exc_info=True, - ) - # Re-raise to let caller handle (feature init may choose to continue) - raise - - # For transient connection issues, dispose and recreate the engine before retrying - transient = ( - isinstance(exc, OperationalError) - or "ConnectionDoesNotExistError" in type(exc).__name__ - or "ConnectionResetError" in type(exc).__name__ - or "WinError 64" in str(exc) - ) - if transient: - try: - global _engine, _SessionLocal - if _engine is not None: - await _engine.dispose() - _engine, _SessionLocal = _create_engine_with_pool() - engine = _engine - logger.warning("Recreated async DB engine after transient connection error during init") - except Exception as recreate_exc: - logger.warning( - f"Failed to recreate engine after transient error: {type(recreate_exc).__name__}: {str(recreate_exc)}", - exc_info=True, - ) - delay = baseDelaySeconds * (2 ** (attempt - 1)) - logger.warning( - f"DB init failed (attempt {attempt}/{maxRetries}): {type(exc).__name__}: {str(exc)}; retrying in {delay:.1f}s" - ) - await asyncio.sleep(delay) - -async def get_all_threads_for_user( - *, - user: User, - session: AsyncSession, -) -> List[ThreadSummary]: - """Get all chat threads for a user. - - Args: - user: The current user. - session: The database session for querying. - - Returns: - List of ThreadSummary objects sorted by date_updated (newest first). - Returns empty list if no threads found. - """ - logger.info(f"Fetching all threads for user {user.id}") - - # Query all threads for this user, ordered by date_updated descending - stmt = ( - select(UserThreadMapping) - .where(UserThreadMapping.user_id == user.id) - .order_by(UserThreadMapping.date_updated.desc()) - ) - result = await session.execute(stmt) - thread_mappings = result.scalars().all() - - # Convert to ThreadSummary objects - threads = [] - for mapping in thread_mappings: - thread_summary = ThreadSummary( - thread_id=mapping.thread_id, - thread_name=mapping.thread_name, - date_created=mapping.date_created.timestamp(), - date_updated=mapping.date_updated.timestamp(), - ) - threads.append(thread_summary) - - logger.info(f"Found {len(threads)} threads for user {user.id}") - return threads - - -async def save_thread_for_user( - *, - thread_id: str, - user: User, - session: AsyncSession, - thread_name: str = "New Chat", -) -> None: - """Save a new chat thread mapping for the user. - - Args: - thread_id: The unique identifier for the chat thread. - user: The current user. - session: The database session for saving. - thread_name: The name of the chat thread. Defaults to "New Chat". - """ - logger.info(f"Saving new thread {thread_id} for user {user.id}") - - # Create new mapping entry - new_mapping = UserThreadMapping( - user_id=user.id, - thread_id=thread_id, - thread_name=thread_name, - ) - - session.add(new_mapping) - await session.commit() - - logger.info(f"Successfully saved thread {thread_id} for user {user.id}") - - -async def get_or_create_thread_for_user( - *, - thread_id: Optional[str], - user: User, - session: AsyncSession, - thread_name: str = "New Chat", - refresh_date_updated: bool = False, -) -> str: - """Get an existing thread or create a new one for the user. - - If thread_id is provided, verifies it exists and belongs to the user. - If thread_id is None, generates a new thread_id and saves it. - - Args: - thread_id: Optional thread identifier. If None, creates a new thread. - user: The current user. - session: The database session for querying/saving. - thread_name: The name for the thread if creating new. Defaults to "New Chat". - refresh_date_updated: If True, refreshes date_updated for existing threads. Defaults to False. - - Returns: - The thread_id to use (either the provided one or newly created). - - Raises: - PermissionError: If the thread does not belong to the user. - ValueError: If the provided thread_id does not exist. - """ - if thread_id: - # If the user provided a thread_id, verify it exists and belongs to them - await assure_thread_exists_and_belongs_to_user( - thread_id=thread_id, user=user, session=session - ) - logger.info(f"Using existing thread {thread_id} for user {user.id}") - - # Refresh date_updated if requested - if refresh_date_updated: - await refresh_thread_date_updated( - thread_id=thread_id, user=user, session=session - ) - - return thread_id - else: - # Generate new thread_id if the user did not provide a thread_id - import uuid - - new_thread_id = f"thread_{uuid.uuid4()}" - await save_thread_for_user( - thread_id=new_thread_id, - user=user, - session=session, - thread_name=thread_name, - ) - logger.info(f"Created new thread {new_thread_id} for user {user.id}") - return new_thread_id - - -async def assure_thread_exists_and_belongs_to_user( - *, - thread_id: str, - user: User, - session: AsyncSession, -) -> None: - """Ensure that the given thread ID exists and belongs to the specified user. - - Args: - thread_id: The unique identifier for the chat thread. - user: The current user. - session: The database session for querying. - Raises: - PermissionError: If the thread does not belong to the user. - ValueError: If the thread does not exist. - """ - # Query the database for the thread mapping - stmt = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id) - result = await session.execute(stmt) - thread_mapping = result.scalar_one_or_none() - - # Check if thread exists - if thread_mapping is None: - logger.warning(f"Thread {thread_id} does not exist") - raise ValueError(f"Thread {thread_id} does not exist") - - # Check if thread belongs to the user - if thread_mapping.user_id != user.id: - logger.warning( - f"User {user.id} attempted to access thread {thread_id} " - f"belonging to user {thread_mapping.user_id}" - ) - raise PermissionError( - f"You do not have permission to access thread {thread_id}" - ) - - logger.info(f"Thread {thread_id} verified for user {user.id}") - - -async def update_thread_name( - *, - thread_id: str, - user: User, - new_thread_name: str, - session: AsyncSession, -) -> None: - """Update the name of an existing chat thread. - - This function performs security checks by including both threadId and userId - in the WHERE clause of the UPDATE query, ensuring users can only update - threads that belong to them. No separate permission check is needed. - - Args: - thread_id: The unique identifier for the chat thread. - user: The current user. - new_thread_name: The new name to set for the thread. - session: The database session for updating. - - Raises: - ValueError: If the thread does not exist or does not belong to the user. - """ - logger.info( - f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}" - ) - - # Update the thread name and date_updated - # Security check: WHERE clause includes both thread_id AND user_id - stmt = ( - update(UserThreadMapping) - .where( - UserThreadMapping.thread_id == thread_id, - UserThreadMapping.user_id == user.id, - ) - .values(thread_name=new_thread_name, date_updated=datetime.now(timezone.utc)) - ) - result = await session.execute(stmt) - await session.commit() - - # Check if any rows were affected - if result.rowcount == 0: - logger.warning( - f"Failed to update thread {thread_id} for user {user.id} - " - "thread does not exist or does not belong to user" - ) - raise ValueError( - f"Thread {thread_id} does not exist or you do not have permission to access it" - ) - - logger.info(f"Successfully updated thread {thread_id} name for user {user.id}") - - -async def refresh_thread_date_updated( - *, - thread_id: str, - user: User, - session: AsyncSession, -) -> None: - """Refresh the date_updated timestamp for an existing chat thread. - - This function performs security checks by including both threadId and userId - in the WHERE clause of the UPDATE query, ensuring users can only update - threads that belong to them. No separate permission check is needed. - - Args: - thread_id: The unique identifier for the chat thread. - user: The current user. - session: The database session for updating. - - Raises: - ValueError: If the thread does not exist or does not belong to the user. - """ - logger.info(f"Refreshing date_updated for thread {thread_id} for user {user.id}") - - # Update the date_updated timestamp - # Security check: WHERE clause includes both thread_id AND user_id - stmt = ( - update(UserThreadMapping) - .where( - UserThreadMapping.thread_id == thread_id, - UserThreadMapping.user_id == user.id, - ) - .values(date_updated=datetime.now(timezone.utc)) - ) - result = await session.execute(stmt) - await session.commit() - - # Check if any rows were affected - if result.rowcount == 0: - logger.warning( - f"Failed to refresh thread {thread_id} for user {user.id} - " - "thread does not exist or does not belong to user" - ) - raise ValueError( - f"Thread {thread_id} does not exist or you do not have permission to access it" - ) - - logger.info( - f"Successfully refreshed date_updated for thread {thread_id} for user {user.id}" - ) - - -async def post_message( - *, - thread_id: str, - message: str, - user: User, - tool_ids: List[str], -) -> ChatMessageResponse: - """Post a chat message to the chatbot and return the response. - - Args: - thread_id: The unique identifier for the chat thread. - message: The content of the chat message. - user: The current user. - tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools. - - Returns: - The response containing the full chat message history and thread ID. - """ - logger.info( - f"User {user.id} posted message to thread {thread_id} with {len(tool_ids)} tools" - ) - - model_name = permissions.get_chatbot_model(user_id=user.id) - system_prompt = permissions.get_system_prompt(user_id=user.id) - - # Get tools from registry (empty list if no tools) - registry = get_registry() - tools = registry.get_tool_instances(tool_ids=tool_ids) - - # Get model and checkpointer - model = get_langchain_model(model_name=model_name) - checkpointer = get_checkpointer() - - # Get context window size from config - context_window_size = int( - APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) - ) - - # Create chatbot instance - chatbot = await Chatbot.create( - model=model, - memory=checkpointer, - system_prompt=system_prompt, - tools=tools, - context_window_size=context_window_size, - ) - - # Send message to chatbot - response = await chatbot.chat(message=message, chat_id=thread_id) - - # Parse the response to the correct format - messages = [] - for msg in response: - # Determine the role of the message - if isinstance(msg, HumanMessage): - role = "user" - elif isinstance(msg, AIMessage): - role = "assistant" - else: - continue # Skip any other message types - - # Skip messages that are structured content, such as tool calls - if not isinstance(msg.content, str): - continue - - # Append message to chat history - item = MessageItem( - role=role, - content=msg.content.strip(), - ) - messages.append(item) - - return ChatMessageResponse(thread_id=thread_id, messages=messages) - - -async def post_message_stream( - *, - thread_id: str, - message: str, - user: User, - tool_ids: List[str], -) -> AsyncIterator[str]: - """Post a chat message to the chatbot and stream progress updates (SSE). - - Args: - thread_id: The unique identifier for the chat thread. - message: The content of the chat message. - user: The current user. - tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools. - - Yields: - Server-Sent Events formatted strings containing status updates and final response. - """ - logger.info( - f"User {user.id} streaming message to thread {thread_id} with {len(tool_ids)} tools" - ) - - try: - model_name = permissions.get_chatbot_model(user_id=user.id) - system_prompt = permissions.get_system_prompt(user_id=user.id) - - # Get tools from registry (empty list if no tools) - registry = get_registry() - tools = registry.get_tool_instances(tool_ids=tool_ids) - - # Get model and checkpointer - model = get_langchain_model(model_name=model_name) - checkpointer = get_checkpointer() - - # Get context window size from config - context_window_size = int( - APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) - ) - - # Create chatbot instance - chatbot = await Chatbot.create( - model=model, - memory=checkpointer, - system_prompt=system_prompt, - tools=tools, - context_window_size=context_window_size, - ) - - # Stream events from chatbot - async for event in chatbot.stream_events(message=message, chat_id=thread_id): - etype = event.get("type") - - # Forward status updates - if etype == "status": - yield f"data: {json.dumps({'type': 'status', 'label': event.get('label')})}\n\n" - continue - - # Forward final response - if etype == "final": - response_from_event = event.get("response") or {} - - # Use the chat history from the final event (already normalized by stream_events) - chat_history_payload = response_from_event.get("chat_history", []) - if isinstance(chat_history_payload, list): - # Convert to MessageItem format - items: List[MessageItem] = [] - for it in chat_history_payload: - role = it.get("role") - content = it.get("content", "") - if role in ("user", "assistant") and content: - items.append( - MessageItem( - role=role, - content=content, - ) - ) - - response = ChatMessageResponse(thread_id=thread_id, messages=items) - # Yield the final response and exit - yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n" - return - else: - # Unexpected payload format - log warning and return empty history - logger.warning( - f"Unexpected chat_history format in final event: {type(chat_history_payload)}" - ) - response = ChatMessageResponse(thread_id=thread_id, messages=[]) - yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n" - return - - # Forward error events - if etype == "error": - yield f"data: {json.dumps(event)}\n\n" - return - - except Exception as e: - error_msg = f"{type(e).__name__}: {str(e) or 'No error message provided'}" - logger.error(f"Error in streaming chat: {error_msg}", exc_info=True) - yield ( - "data: " - + json.dumps( - { - "type": "error", - "message": f"An error occurred while processing your request: {error_msg}", - } - ) - + "\n\n" - ) - - -# Module-level singleton for minimal app used to read thread state -_MINIMAL_APP = None - - -def _build_minimal_app(*, checkpointer): - """Build a minimal LangGraph app for reading thread state. - - This creates a valid graph with a no-op node that we never actually run. - LangGraph requires a valid graph structure (with edges from START) to compile, - even though we only use it to call aget_state() to read from the checkpointer. - - Args: - checkpointer: The checkpointer to attach to the graph. - - Returns: - A compiled StateGraph that can be used to read thread state. - """ - graph = StateGraph(MessagesState) - - # No-op node that returns the state unchanged - def noop(state: dict) -> dict: - return state - - graph.add_node("noop", noop) - graph.add_edge(START, "noop") - graph.add_edge("noop", END) - - return graph.compile(checkpointer=checkpointer) - - -def _get_minimal_app(): - """Get the module-level singleton minimal app. - - Returns: - The cached minimal app, building it on first access. - """ - global _MINIMAL_APP - if _MINIMAL_APP is None: - _MINIMAL_APP = _build_minimal_app(checkpointer=get_checkpointer()) - return _MINIMAL_APP - - -async def get_thread_messages_from_langgraph( - *, - thread_id: str, -) -> List[dict]: - """Retrieve and format messages from LangGraph checkpointer. - - Args: - thread_id: The unique identifier for the chat thread. - - Returns: - List of message dicts with role and content. - """ - ROLE_MAP = {"human": "user", "ai": "assistant"} - - # Get the minimal app (singleton, built once) - app = _get_minimal_app() - - cfg = {"configurable": {"thread_id": thread_id}} - state = await app.aget_state(cfg) - - messages = [] - for msg in state.values.get("messages", []): - # Skip system and tool messages - only include user and assistant - if msg.type not in ["human", "ai"]: - continue - - # Skip messages with non-string content (e.g., tool calls) - if not isinstance(msg.content, str): - continue - - messages.append( - { - "role": ROLE_MAP.get(msg.type, msg.type), - "content": msg.content, - } - ) - - return messages - - -async def get_thread_detail_for_user( - *, - thread_id: str, - user: User, - session: AsyncSession, -) -> ThreadDetail: - """Get detailed thread information with message history from LangGraph. - - Args: - thread_id: The unique identifier for the chat thread. - user: The current user. - session: The database session for querying. - - Returns: - ThreadDetail object with thread metadata and message history. - - Raises: - PermissionError: If the thread does not belong to the user. - ValueError: If the thread does not exist. - """ - logger.info(f"Getting thread detail for thread {thread_id} for user {user.id}") - - # Verify thread exists and belongs to user - await assure_thread_exists_and_belongs_to_user( - thread_id=thread_id, user=user, session=session - ) - - # Get thread metadata from database - stmt = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id) - result = await session.execute(stmt) - thread_mapping = result.scalar_one() - - # Get messages from LangGraph checkpointer (optimized - no full chatbot needed) - message_dicts = await get_thread_messages_from_langgraph(thread_id=thread_id) - - # Convert to MessageItem objects - messages = [MessageItem(**m) for m in message_dicts] - - logger.info( - f"Retrieved thread {thread_id} with {len(messages)} messages for user {user.id}" - ) - - # Return ThreadDetail - return ThreadDetail( - thread_id=thread_id, - date_created=thread_mapping.date_created.timestamp(), - date_updated=thread_mapping.date_updated.timestamp(), - messages=messages, - ) - - -async def delete_thread_for_user( - *, - thread_id: str, - user: User, - session: AsyncSession, -) -> None: - """Delete a chat thread for a user from both LangGraph and the database. - - Args: - thread_id: The unique identifier for the chat thread. - user: The current user. - session: The database session for deleting. - - Raises: - PermissionError: If the thread does not belong to the user. - ValueError: If the thread does not exist. - """ - logger.info(f"Deleting thread {thread_id} for user {user.id}") - - # Verify thread exists and belongs to user - await assure_thread_exists_and_belongs_to_user( - thread_id=thread_id, user=user, session=session - ) - - # Delete from LangGraph checkpointer (optimized - no app/tools/model needed) - checkpointer = get_checkpointer() - try: - await checkpointer.adelete_thread(thread_id) - logger.info(f"Deleted thread {thread_id} from LangGraph checkpointer") - except Exception as e: - logger.error( - f"Failed to delete thread {thread_id} from LangGraph: {type(e).__name__}: {str(e)}", - exc_info=True, - ) - raise ValueError( - f"Failed to delete thread from LangGraph: {type(e).__name__}: {str(e)}" - ) - - # Delete from database - stmt = delete(UserThreadMapping).where( - UserThreadMapping.thread_id == thread_id, - UserThreadMapping.user_id == user.id, - ) - result = await session.execute(stmt) - await session.commit() - - # Check if any rows were deleted - if result.rowcount == 0: - logger.warning( - f"Failed to delete thread {thread_id} from database for user {user.id} - " - "thread does not exist or does not belong to user" - ) - raise ValueError( - f"Thread {thread_id} does not exist or you do not have permission to access it" - ) - - logger.info(f"Successfully deleted thread {thread_id} for user {user.id}") - - -# Tool Management Functions - - -async def get_all_tools(*, session: AsyncSession) -> List[dict]: - """Get all tools from the database. - - Args: - session: The database session for querying. - - Returns: - List of tool dictionaries with all tool information. - """ - from modules.features.chatBot.subChatbotDatabase import Tool - - logger.info("Fetching all tools from database") - - stmt = select(Tool).order_by(Tool.category, Tool.name) - result = await session.execute(stmt) - tools = result.scalars().all() - - tool_list = [] - for tool in tools: - tool_dict = { - "id": str(tool.id), - "tool_id": tool.tool_id, - "name": tool.name, - "label": tool.label, - "category": tool.category, - "description": tool.description, - "is_active": tool.is_active, - "date_created": tool.date_created.timestamp(), - "date_updated": tool.date_updated.timestamp(), - } - tool_list.append(tool_dict) - - logger.info(f"Retrieved {len(tool_list)} tools from database") - return tool_list - - -async def grant_tool_to_user( - *, user_id: str, tool_id: str, session: AsyncSession -) -> None: - """Grant a tool to a user. - - Args: - user_id: The user ID to grant the tool to. - tool_id: The tool UUID from the tools table. - session: The database session for querying/updating. - - Raises: - ValueError: If the tool doesn't exist, is not active, or user already has the tool. - """ - from modules.features.chatBot.subChatbotDatabase import Tool, UserToolMapping - import uuid - - logger.info(f"Granting tool {tool_id} to user {user_id}") - - # Convert tool_id string to UUID - try: - tool_uuid = uuid.UUID(tool_id) - except ValueError: - raise ValueError(f"Invalid tool ID format: {tool_id}") - - # Check if tool exists and is active - stmt = select(Tool).where(Tool.id == tool_uuid) - result = await session.execute(stmt) - tool = result.scalar_one_or_none() - - if tool is None: - raise ValueError(f"Tool with ID {tool_id} does not exist") - - if not tool.is_active: - raise ValueError( - f"Cannot grant inactive tool '{tool.label}' (tool_id: {tool.tool_id}). " - f"Please activate the tool first before granting it to users." - ) - - # Check if user already has this tool - stmt = select(UserToolMapping).where( - UserToolMapping.user_id == user_id, UserToolMapping.tool_id == tool_uuid - ) - result = await session.execute(stmt) - existing_mapping = result.scalar_one_or_none() - - if existing_mapping is not None: - raise ValueError( - f"User {user_id} already has access to tool '{tool.label}' (tool_id: {tool.tool_id})" - ) - - # Create new mapping - new_mapping = UserToolMapping( - user_id=user_id, - tool_id=tool_uuid, - is_active=True, - ) - - session.add(new_mapping) - await session.commit() - - logger.info(f"Successfully granted tool {tool_id} ({tool.label}) to user {user_id}") - - -async def revoke_tool_from_user( - *, user_id: str, tool_id: str, session: AsyncSession -) -> None: - """Revoke a tool from a user by deleting the mapping. - - Args: - user_id: The user ID to revoke the tool from. - tool_id: The tool UUID from the tools table. - session: The database session for deleting. - - Raises: - ValueError: If the mapping doesn't exist. - """ - from modules.features.chatBot.subChatbotDatabase import UserToolMapping - import uuid - - logger.info(f"Revoking tool {tool_id} from user {user_id}") - - # Convert tool_id string to UUID - try: - tool_uuid = uuid.UUID(tool_id) - except ValueError: - raise ValueError(f"Invalid tool ID format: {tool_id}") - - # Delete the mapping - stmt = delete(UserToolMapping).where( - UserToolMapping.user_id == user_id, UserToolMapping.tool_id == tool_uuid - ) - result = await session.execute(stmt) - await session.commit() - - # Check if any rows were deleted - if result.rowcount == 0: - raise ValueError( - f"User {user_id} does not have access to tool {tool_id}, or the mapping does not exist" - ) - - logger.info(f"Successfully revoked tool {tool_id} from user {user_id}") - - -async def update_tool( - *, - tool_id: str, - label: Optional[str], - description: Optional[str], - session: AsyncSession, -) -> List[str]: - """Update a tool's label and/or description. - - Args: - tool_id: The tool UUID to update. - label: Optional new label for the tool. - description: Optional new description for the tool. - session: The database session for updating. - - Returns: - List of updated field names. - - Raises: - ValueError: If the tool doesn't exist or no fields provided to update. - """ - from modules.features.chatBot.subChatbotDatabase import Tool - import uuid - - logger.info(f"Updating tool {tool_id}") - - # Validate that at least one field is provided - if label is None and description is None: - raise ValueError("At least one field (label or description) must be provided") - - # Convert tool_id string to UUID - try: - tool_uuid = uuid.UUID(tool_id) - except ValueError: - raise ValueError(f"Invalid tool ID format: {tool_id}") - - # Check if tool exists - stmt = select(Tool).where(Tool.id == tool_uuid) - result = await session.execute(stmt) - tool = result.scalar_one_or_none() - - if tool is None: - raise ValueError(f"Tool with ID {tool_id} does not exist") - - # Build update values - update_values = {"date_updated": datetime.now(timezone.utc)} - updated_fields = [] - - if label is not None: - update_values["label"] = label - updated_fields.append("label") - - if description is not None: - update_values["description"] = description - updated_fields.append("description") - - # Update the tool - stmt = update(Tool).where(Tool.id == tool_uuid).values(**update_values) - await session.execute(stmt) - await session.commit() - - logger.info(f"Successfully updated tool {tool_id}, fields: {updated_fields}") - return updated_fields - - -async def get_tools_for_user(*, user_id: str, session: AsyncSession) -> List[dict]: - """Get all tools granted to a specific user. - - Args: - user_id: The user ID to get tools for. - session: The database session for querying. - - Returns: - List of tool dictionaries with all tool information. - """ - from modules.features.chatBot.subChatbotDatabase import Tool, UserToolMapping - - logger.info(f"Fetching tools for user {user_id}") - - # Query tools that are granted to the user - # Join UserToolMapping with Tool table - # Filter by user_id and active status - stmt = ( - select(Tool) - .join(UserToolMapping, Tool.id == UserToolMapping.tool_id) - .where( - UserToolMapping.user_id == user_id, - UserToolMapping.is_active == True, - Tool.is_active == True, - ) - .order_by(Tool.category, Tool.name) - ) - result = await session.execute(stmt) - tools = result.scalars().all() - - tool_list = [] - for tool in tools: - tool_dict = { - "id": str(tool.id), - "tool_id": tool.tool_id, - "name": tool.name, - "label": tool.label, - "category": tool.category, - "description": tool.description, - "is_active": tool.is_active, - "date_created": tool.date_created.timestamp(), - "date_updated": tool.date_updated.timestamp(), - } - tool_list.append(tool_dict) - - logger.info(f"Retrieved {len(tool_list)} tools for user {user_id}") - return tool_list - - -async def validate_and_get_tools_for_request( - *, - user_id: str, - requested_tool_ids: Optional[List[str]], - session: AsyncSession, -) -> List[str]: - """Validate and get tool IDs for a chat request. - - This function validates that the user has access to the requested tools. - If no tools are requested (None), it returns all tools the user has access to. - If an empty list is provided, it returns an empty list (no tools). - - Args: - user_id: The user ID making the request. - requested_tool_ids: Optional list of tool UUIDs (id field) requested by the user. - - None: Use all tools the user has access to - - []: Use no tools at all - - ["uuid1", "uuid2"]: Use only the specified tools - session: The database session for querying. - - Returns: - List of validated tool IDs (tool_id field, not UUID) that the user can use. - - Raises: - PermissionError: If the user requests tools they don't have access to. - ValueError: If the user has no tools available when trying to use all tools. - """ - from modules.features.chatBot.subChatbotDatabase import Tool, UserToolMapping - import uuid - - logger.info(f"Validating tools for user {user_id}") - - # If empty list is explicitly provided, return empty list (no tools) - if requested_tool_ids is not None and len(requested_tool_ids) == 0: - logger.info( - f"Empty tool list requested, chatbot will run without tools for user {user_id}" - ) - return [] - - # Get all tools the user has access to - stmt = ( - select(Tool) - .join(UserToolMapping, Tool.id == UserToolMapping.tool_id) - .where( - UserToolMapping.user_id == user_id, - UserToolMapping.is_active == True, - Tool.is_active == True, - ) - ) - result = await session.execute(stmt) - user_tools = result.scalars().all() - - # Create mappings for both UUID and tool_id - user_tool_ids_by_uuid = {str(tool.id): tool.tool_id for tool in user_tools} - user_tool_ids = set(user_tool_ids_by_uuid.values()) - - if not user_tool_ids: - logger.warning(f"User {user_id} has no tools available") - raise ValueError("User does not have access to any chatbot tools") - - # If no specific tools requested (None), return all user's tools - if requested_tool_ids is None: - logger.info( - f"No specific tools requested, returning all {len(user_tool_ids)} tools for user {user_id}" - ) - return list(user_tool_ids) - - # Convert requested UUIDs to tool_ids and validate access - requested_tool_ids_result = [] - unauthorized_uuids = [] - - for requested_uuid in requested_tool_ids: - if requested_uuid in user_tool_ids_by_uuid: - # User has access to this tool - requested_tool_ids_result.append(user_tool_ids_by_uuid[requested_uuid]) - else: - # User doesn't have access to this tool - unauthorized_uuids.append(requested_uuid) - - if unauthorized_uuids: - logger.warning( - f"User {user_id} requested unauthorized tool UUIDs: {unauthorized_uuids}" - ) - raise PermissionError( - f"You do not have access to the following tools: {', '.join(unauthorized_uuids)}" - ) - - logger.info( - f"Validated {len(requested_tool_ids_result)} requested tools for user {user_id}" - ) - return requested_tool_ids_result diff --git a/modules/features/chatBot/subChatbotDatabase.py b/modules/features/chatBot/subChatbotDatabase.py deleted file mode 100644 index 1dc4ebe6..00000000 --- a/modules/features/chatBot/subChatbotDatabase.py +++ /dev/null @@ -1,197 +0,0 @@ -from typing import AsyncIterator -import uuid -from datetime import datetime, timezone - -from fastapi import Request -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint - - -class Base(DeclarativeBase): - pass - - -# Tools Table -class Tool(Base): - """Available chatbot tools. - - Stores information about all available tools that can be assigned to users. - Each tool has a unique tool_id that corresponds to the registry tool_id. - """ - - __tablename__ = "tools" - id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) - tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) - name: Mapped[str] = mapped_column(String(255), nullable=False) - label: Mapped[str] = mapped_column(String(255), nullable=False) - category: Mapped[str] = mapped_column(String(50), nullable=False) - description: Mapped[str] = mapped_column(String(1000), nullable=False) - is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - date_created: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - default=lambda: datetime.now(timezone.utc), - ) - date_updated: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - default=lambda: datetime.now(timezone.utc), - ) - - -# User-Tool Mapping Table -class UserToolMapping(Base): - """Mapping of users to their available tools. - - Many-to-many relationship between users and tools. - - One user can have multiple tools - - One tool can be assigned to multiple users - - The combination of user_id and tool_id is unique. - """ - - __tablename__ = "user_tools" - __table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),) - - id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) - user_id: Mapped[str] = mapped_column(String(255), nullable=False) - tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False) - is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - date_granted: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - default=lambda: datetime.now(timezone.utc), - ) - date_updated: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - default=lambda: datetime.now(timezone.utc), - ) - - -# User Thread Mapping Table -class UserThreadMapping(Base): - """Mapping of users to their chat threads. - - Used to keep track of which user owns which chat thread. - Also stores meta data like thread name. - - 1:N relationship between user and thread. A thread belongs to exactly one user. - A user can have multiple threads. - Thread_id is unique in the table. - """ - - __tablename__ = "user_threads" - id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) - user_id: Mapped[str] = mapped_column(String(255), nullable=False) - thread_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) - thread_name: Mapped[str] = mapped_column(String(255), nullable=False) - date_created: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - default=lambda: datetime.now(timezone.utc), - ) - date_updated: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - default=lambda: datetime.now(timezone.utc), - ) - - -# Dependency that pulls the sessionmaker off app.state -# This is set in app.py on startup in @asynccontextmanager -# TODO: If we use SQLAlchemy in other places, we can move this to a shared module -async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]: - SessionLocal: async_sessionmaker[AsyncSession] = ( - request.app.state.checkpoint_sessionmaker - ) - async with SessionLocal() as session: - yield session - - -# Optional helper to init tables at startup (demo only) -async def init_models(engine) -> None: - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - -async def sync_tools_from_registry(session: AsyncSession) -> None: - """Sync tools from tool registry to database. - - This function: - - Adds new tools from the registry to the database - - Updates existing tools with current registry information - - Marks tools not present in the registry as inactive - - Should be called on application startup after database initialization. - - Args: - session: Active database session - """ - import logging - from sqlalchemy import select - - from modules.features.chatBot.utils.toolRegistry import get_registry - - logger = logging.getLogger(__name__) - logger.info("Syncing tools from registry to database...") - - # Get all tools from the registry - registry = get_registry() - registry_tools = registry.get_all_tools() - - # Create a set of tool_ids from the registry - registry_tool_ids = {tool.tool_id for tool in registry_tools} - - logger.info(f"Found {len(registry_tools)} tools in registry") - - # Get all existing tools from the database - result = await session.execute(select(Tool)) - db_tools = result.scalars().all() - db_tools_by_tool_id = {tool.tool_id: tool for tool in db_tools} - - logger.info(f"Found {len(db_tools)} tools in database") - - # Track changes - added_count = 0 - updated_count = 0 - deactivated_count = 0 - - # Sync tools from registry to database - for registry_tool in registry_tools: - if registry_tool.tool_id in db_tools_by_tool_id: - # Tool exists - update it - # Preserve label and description (user-editable fields) - db_tool = db_tools_by_tool_id[registry_tool.tool_id] - db_tool.name = registry_tool.name - db_tool.category = registry_tool.category - db_tool.is_active = True - db_tool.date_updated = datetime.now(timezone.utc) - updated_count += 1 - logger.debug(f"Updated tool: {registry_tool.tool_id}") - else: - # Tool doesn't exist - create it - new_tool = Tool( - tool_id=registry_tool.tool_id, - name=registry_tool.name, - label=registry_tool.tool_id, # Use tool_id as label per spec - category=registry_tool.category, - description=registry_tool.description or "", - is_active=True, - ) - session.add(new_tool) - added_count += 1 - logger.debug(f"Added new tool: {registry_tool.tool_id}") - - # Mark tools not in registry as inactive - for db_tool in db_tools: - if db_tool.tool_id not in registry_tool_ids and db_tool.is_active: - db_tool.is_active = False - db_tool.date_updated = datetime.now(timezone.utc) - deactivated_count += 1 - logger.debug(f"Deactivated tool not in registry: {db_tool.tool_id}") - - logger.info( - f"Tool sync complete: {added_count} added, {updated_count} updated, {deactivated_count} deactivated" - ) diff --git a/modules/features/chatBot/utils/checkpointer.py b/modules/features/chatBot/utils/checkpointer.py deleted file mode 100644 index a51e7455..00000000 --- a/modules/features/chatBot/utils/checkpointer.py +++ /dev/null @@ -1,106 +0,0 @@ -"""PostgreSQL checkpointer utilities for LangGraph memory.""" - -import sys -import asyncio -import logging -from typing import Optional - -# Fix for Windows asyncio compatibility with psycopg (backup in case app.py fix didn't apply) -if sys.platform == 'win32': - try: - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - except RuntimeError: - pass # Already set - -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver -from psycopg_pool import AsyncConnectionPool -from psycopg.rows import dict_row -from modules.shared.configuration import APP_CONFIG - -logger = logging.getLogger(__name__) - -# Global checkpointer instance -_checkpointer_instance: Optional[AsyncPostgresSaver] = None -_connection_pool: Optional[AsyncConnectionPool] = None - - -async def initialize_checkpointer() -> None: - """Initialize the PostgreSQL checkpointer for LangGraph. - - This should be called during application startup. - Creates a connection pool and PostgresSaver instance. - """ - global _checkpointer_instance, _connection_pool - - if _checkpointer_instance is not None: - logger.info("Checkpointer already initialized") - return - - try: - # Get database configuration from environment - host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost") - database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat") - user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev") - password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET") - port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432") - - # Build connection string - connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}" - - # Create async connection pool - _connection_pool = AsyncConnectionPool( - conninfo=connection_string, - min_size=2, - max_size=10, - kwargs={"autocommit": True, "row_factory": dict_row}, - ) - - # Initialize the connection pool - await _connection_pool.open() - - # Create AsyncPostgresSaver with the pool - _checkpointer_instance = AsyncPostgresSaver(_connection_pool) - - # Setup the checkpointer (creates tables if needed) - await _checkpointer_instance.setup() - - logger.info("PostgreSQL checkpointer initialized successfully") - - except Exception as e: - logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}") - raise - - -async def close_checkpointer() -> None: - """Close the checkpointer and connection pool. - - This should be called during application shutdown. - """ - global _checkpointer_instance, _connection_pool - - if _connection_pool is not None: - try: - await _connection_pool.close() - logger.info("PostgreSQL checkpointer connection pool closed") - except Exception as e: - logger.error(f"Error closing checkpointer connection pool: {str(e)}") - - _checkpointer_instance = None - _connection_pool = None - - -def get_checkpointer() -> AsyncPostgresSaver: - """Get the global PostgreSQL checkpointer instance. - - Returns: - The initialized AsyncPostgresSaver instance - - Raises: - RuntimeError: If checkpointer is not initialized - """ - if _checkpointer_instance is None: - raise RuntimeError( - "PostgreSQL checkpointer not initialized. " - "Call initialize_checkpointer() during application startup." - ) - return _checkpointer_instance diff --git a/modules/features/chatBot/utils/permissions.py b/modules/features/chatBot/utils/permissions.py deleted file mode 100644 index d2fb4d65..00000000 --- a/modules/features/chatBot/utils/permissions.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Mock permissions module for chatbot access control. - -This module provides mock permission functions that will be replaced -with actual database-driven permissions in the future. -""" - -from datetime import datetime - -from modules.features.chatBot.utils.toolRegistry import get_registry - - -# TODO: Replace these mock implementations with actual database queries - - -def get_chatbot_tools(*, user_id: str) -> list[str]: - """Get list of tool IDs that the chatbot can use for a given user.""" - registry = get_registry() - return registry.list_tool_ids() - - -def get_chatbot_model(*, user_id: str) -> str: - """Gets the chatbot model(s) a user is allowed to use.""" - return "claude_4_5" - - -def get_system_prompt(*, user_id: str) -> str: - """Get the system prompt for a user's chatbot session. - - This is a mock implementation that returns a generic prompt with today's date. - In production, this will query the database for user-specific or role-specific prompts. - - Args: - user_id: The unique identifier of the user - - Returns: - The system prompt string with the current date - """ - current_date = datetime.now().strftime("%Y-%m-%d") - return f"You're a smart assistant. Today is {current_date}" diff --git a/modules/features/chatBot/utils/toolRegistry.py b/modules/features/chatBot/utils/toolRegistry.py deleted file mode 100644 index 5f5d14d6..00000000 --- a/modules/features/chatBot/utils/toolRegistry.py +++ /dev/null @@ -1,305 +0,0 @@ -"""Tool registry for auto-discovering and managing chatbot tools. - -This module provides a central registry that automatically discovers all tools -in the chatbotTools directory structure and provides methods to query them. -The registry is built in-memory at startup and does not require a database. -""" - -import importlib -import inspect -import logging -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, List, Optional - -from langchain_core.tools import BaseTool - -logger = logging.getLogger(__name__) - - -@dataclass -class ToolMetadata: - """Metadata about a discovered chatbot tool. - - Attributes: - tool_id: Unique identifier (e.g., 'shared.tavily_search') - name: Function name of the tool - category: Category of the tool ('shared' or 'customer') - description: Tool description from docstring - tool_instance: The actual LangChain tool instance - module_path: Full Python module path - """ - - tool_id: str - name: str - category: str - description: str - tool_instance: BaseTool - module_path: str - - def __str__(self) -> str: - """Return a pretty-printed string representation for logging.""" - return ( - f"ToolMetadata(\n" - f" tool_id='{self.tool_id}',\n" - f" name='{self.name}',\n" - f" category='{self.category}',\n" - f" description='{self.description}',\n" - f" module_path='{self.module_path}'\n" - f")" - ) - - -class ToolRegistry: - """Central registry for all chatbot tools. - - This class discovers and catalogs all tools decorated with @tool in the - chatbotTools directory structure. Tools are automatically discovered at - initialization by scanning the filesystem. - - The registry provides methods to query tools by ID, category, or get all tools. - """ - - def __init__(self) -> None: - """Initialize an empty tool registry.""" - self._tools: Dict[str, ToolMetadata] = {} - self._initialized: bool = False - - def initialize(self) -> None: - """Discover and register all tools from the chatbotTools directory. - - This method scans both sharedTools and customerTools directories, - imports all tool*.py modules, and extracts functions decorated with @tool. - - This method is idempotent - calling it multiple times has no effect - after the first initialization. - """ - if self._initialized: - logger.debug("Tool registry already initialized, skipping") - return - - logger.info("Initializing tool registry...") - - # Get base path to chatbotTools directory - base_path = Path(__file__).parent.parent / "chatbotTools" - - if not base_path.exists(): - logger.warning(f"chatbotTools directory not found at {base_path}") - self._initialized = True - return - - # Discover tools in each category - self._discover_category( - category_path=base_path / "sharedTools", category="shared" - ) - self._discover_category( - category_path=base_path / "customerTools", category="customer" - ) - - self._initialized = True - logger.info(f"Tool registry initialized with {len(self._tools)} tools") - - def _discover_category(self, *, category_path: Path, category: str) -> None: - """Discover all tools in a specific category directory. - - Args: - category_path: Path to the category directory (sharedTools or customerTools) - category: Category name ('shared' or 'customer') - """ - if not category_path.exists(): - logger.warning(f"Category directory not found: {category_path}") - return - - logger.debug(f"Discovering tools in category: {category}") - - # Find all tool*.py files (excluding __init__.py) - tool_files = [ - f for f in category_path.glob("tool*.py") if f.name != "__init__.py" - ] - - for tool_file in tool_files: - self._import_and_register_tools( - tool_file=tool_file, category=category, category_path=category_path - ) - - logger.debug(f"Discovered {len(tool_files)} tool files in {category}") - - def _import_and_register_tools( - self, *, tool_file: Path, category: str, category_path: Path - ) -> None: - """Import a tool module and register all discovered tools. - - Args: - tool_file: Path to the tool Python file - category: Category name ('shared' or 'customer') - category_path: Path to the category directory - """ - # Construct module name - module_name = ( - f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}" - ) - - try: - # Import the module - module = importlib.import_module(module_name) - - # Find all BaseTool instances in the module - tools_found = 0 - for name, obj in inspect.getmembers(module): - if isinstance(obj, BaseTool): - self._register_tool( - tool_instance=obj, - name=name, - category=category, - module_path=module_name, - ) - tools_found += 1 - - if tools_found == 0: - logger.warning(f"No tools found in {module_name}") - else: - logger.debug(f"Loaded {tools_found} tool(s) from {module_name}") - - except ImportError as e: - logger.error( - f"Import error loading tools from {module_name}: {str(e)}. " - f"This tool will not be available." - ) - except Exception as e: - logger.error( - f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}" - ) - - def _register_tool( - self, *, tool_instance: BaseTool, name: str, category: str, module_path: str - ) -> None: - """Register a single tool in the registry. - - Args: - tool_instance: The LangChain tool instance - name: Function name of the tool - category: Category name ('shared' or 'customer') - module_path: Full Python module path - """ - tool_id = f"{category}.{name}" - - # Check for duplicate tool IDs - if tool_id in self._tools: - logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting") - - metadata = ToolMetadata( - tool_id=tool_id, - name=name, - category=category, - description=tool_instance.description or "", - tool_instance=tool_instance, - module_path=module_path, - ) - - self._tools[tool_id] = metadata - logger.debug(f"Registered tool: {tool_id}") - - def get_all_tools(self) -> List[ToolMetadata]: - """Get all registered tools. - - Returns: - List of all tool metadata objects - """ - return list(self._tools.values()) - - def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]: - """Get a specific tool by its ID. - - Args: - tool_id: The tool identifier (e.g., 'shared.tavily_search') - - Returns: - Tool metadata if found, None otherwise - """ - return self._tools.get(tool_id) - - def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]: - """Get all tools in a specific category. - - Args: - category: Category name ('shared' or 'customer') - - Returns: - List of tool metadata for the specified category - """ - return [t for t in self._tools.values() if t.category == category] - - def list_tool_ids(self) -> List[str]: - """Get a list of all registered tool IDs. - - Returns: - List of tool ID strings - """ - return list(self._tools.keys()) - - def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]: - """Get actual tool instances for a list of tool IDs. - - This is useful for filtering tools based on user permissions. - - Args: - tool_ids: List of tool IDs to retrieve - - Returns: - List of BaseTool instances for the specified IDs - """ - instances = [] - for tool_id in tool_ids: - metadata = self.get_tool(tool_id=tool_id) - if metadata: - instances.append(metadata.tool_instance) - else: - logger.warning(f"Tool ID not found in registry: {tool_id}") - return instances - - @property - def is_initialized(self) -> bool: - """Check if the registry has been initialized. - - Returns: - True if initialized, False otherwise - """ - return self._initialized - - -# Global registry instance -_registry: Optional[ToolRegistry] = None - - -def get_registry() -> ToolRegistry: - """Get the global tool registry instance. - - This function ensures the registry is initialized on first access. - Subsequent calls return the same instance. - - Returns: - The global ToolRegistry instance - """ - global _registry - - if _registry is None: - _registry = ToolRegistry() - - if not _registry.is_initialized: - _registry.initialize() - - return _registry - - -def reinitialize_registry() -> ToolRegistry: - """Force reinitialize the tool registry. - - This is useful for testing or when tools are added dynamically. - - Returns: - The reinitialized ToolRegistry instance - """ - global _registry - _registry = ToolRegistry() - _registry.initialize() - return _registry diff --git a/modules/features/featuresLifecycle.py b/modules/features/featuresLifecycle.py index 7089ef31..85650c14 100644 --- a/modules/features/featuresLifecycle.py +++ b/modules/features/featuresLifecycle.py @@ -13,14 +13,15 @@ async def start() -> None: from modules.features.syncDelta import mainSyncDelta mainSyncDelta.startSyncManager(eventUser) - # Feature ChatBot - from modules.features.chatBot.mainChatBot import start as startChatBot - await startChatBot() + # Feature ... + + return True + async def stop() -> None: """ Stop feature triggers and background managers """ - # Feature ChatBot - from modules.features.chatBot.mainChatBot import stop as stopChatBot - await stopChatBot() \ No newline at end of file + # Feature ... + + return True diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py deleted file mode 100644 index 75a2ffac..00000000 --- a/modules/routes/routeChatbot.py +++ /dev/null @@ -1,653 +0,0 @@ -from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.requests import Request -from fastapi.responses import StreamingResponse -import logging - -from modules.datamodels.datamodelUam import User, UserPrivilege -from modules.security.auth import getCurrentUser, limiter - -from sqlalchemy.ext.asyncio import AsyncSession -from modules.features.chatBot.subChatbotDatabase import get_async_db_session -from modules.features.chatBot.mainChatBot import ( - get_or_create_thread_for_user, - ) -from modules.datamodels.datamodelChatbot import ( - ChatMessageRequest, - MessageItem, - ChatMessageResponse, - ThreadSummary, - ThreadListResponse, - ThreadDetail, - RenameThreadRequest, - DeleteResponse, - ToolListResponse, - ToolInfo, - GrantToolRequest, - GrantToolResponse, - RevokeToolRequest, - RevokeToolResponse, - UpdateToolRequest, - UpdateToolResponse, - ) -from modules.features.chatBot import mainChatBot as chat_service - - -logger = logging.getLogger(__name__) - -router = APIRouter( - prefix="/api/chatbot", - tags=["Chatbot"], - responses={404: {"description": "Not found"}}, -) - - -# --- Actual endpoints for chatbot --- - - -@router.post("/message/stream") -@limiter.limit("30/minute") -async def post_chat_message_stream( - *, - request: Request, - message_request: ChatMessageRequest, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> StreamingResponse: - """ - Post a message to a chat thread with streaming progress updates. - Creates a new thread if thread_id is not provided. - - Returns Server-Sent Events (SSE) stream with status updates and final response. - """ - try: - # Validate and get tools for the request - tool_ids = await chat_service.validate_and_get_tools_for_request( - user_id=currentUser.id, - requested_tool_ids=message_request.tools, - session=session, - ) - - # Get or create thread using helper function - thread_id = await get_or_create_thread_for_user( - thread_id=message_request.thread_id, - user=currentUser, - session=session, - thread_name=message_request.message[:100], - refresh_date_updated=True, - ) - - logger.info( - f"User {currentUser.id} posted streaming message to thread {thread_id}" - ) - - return StreamingResponse( - chat_service.post_message_stream( - thread_id=thread_id, - message=message_request.message, - user=currentUser, - tool_ids=tool_ids, - ), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) - - except PermissionError as e: - logger.error(f"Permission error: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=str(e) or "Permission denied", - ) - except ValueError as e: - logger.error(f"Validation error: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=str(e) or "Permission denied", - ) - except Exception as e: - logger.error( - f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.post("/message", response_model=ChatMessageResponse) -@limiter.limit("30/minute") -async def post_chat_message( - *, - request: Request, - message_request: ChatMessageRequest, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> ChatMessageResponse: - """ - Post a message to a chat thread and get assistant response (non-streaming). - Creates a new thread if thread_id is not provided. - - For streaming updates, use the /message/stream endpoint instead. - """ - try: - # Validate and get tools for the request - tool_ids = await chat_service.validate_and_get_tools_for_request( - user_id=currentUser.id, - requested_tool_ids=message_request.tools, - session=session, - ) - - # Get or create thread using helper function - thread_id = await get_or_create_thread_for_user( - thread_id=message_request.thread_id, - user=currentUser, - session=session, - thread_name=message_request.message[:100], - refresh_date_updated=True, - ) - - logger.info(f"User {currentUser.id} posted message to thread {thread_id}") - - response = await chat_service.post_message( - thread_id=thread_id, - message=message_request.message, - user=currentUser, - tool_ids=tool_ids, - ) - - return response - - except PermissionError as e: - logger.error(f"Permission error: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=str(e) or "Permission denied", - ) - except ValueError as e: - logger.error(f"Validation error: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=str(e) or "Permission denied", - ) - except Exception as e: - logger.error( - f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.get("/threads", response_model=ThreadListResponse) -@limiter.limit("30/minute") -async def get_all_threads( - *, - request: Request, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> ThreadListResponse: - """ - Get all chat threads for the current user. - """ - try: - # Get all threads for the current user - threads = await chat_service.get_all_threads_for_user( - user=currentUser, session=session - ) - - logger.info(f"User {currentUser.id} retrieved {len(threads)} threads") - - return ThreadListResponse(threads=threads) - - except Exception as e: - logger.error( - f"Error retrieving threads: {type(e).__name__}: {str(e)}", exc_info=True - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve threads: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.get("/threads/{thread_id}", response_model=ThreadDetail) -@limiter.limit("30/minute") -async def get_thread_by_id( - *, - request: Request, - thread_id: str, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> ThreadDetail: - """ - Get a specific chat thread with all its messages from LangGraph checkpointer. - """ - try: - thread_detail = await chat_service.get_thread_detail_for_user( - thread_id=thread_id, - user=currentUser, - session=session, - ) - - logger.info(f"User {currentUser.id} retrieved thread {thread_id}") - return thread_detail - - except ValueError as e: - logger.error(f"Thread not found: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=str(e) or "Thread not found", - ) - except PermissionError as e: - logger.error(f"Permission denied: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=str(e) or "Permission denied", - ) - except Exception as e: - logger.error( - f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}", - exc_info=True, - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve thread: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.patch("/threads/{thread_id}", response_model=DeleteResponse) -@limiter.limit("30/minute") -async def rename_thread( - *, - request: Request, - thread_id: str, - rename_request: RenameThreadRequest, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> DeleteResponse: - """ - Rename a chat thread. - """ - try: - await chat_service.update_thread_name( - thread_id=thread_id, - user=currentUser, - new_thread_name=rename_request.new_name, - session=session, - ) - - logger.info( - f"User {currentUser.id} renamed thread {thread_id} to '{rename_request.new_name}'" - ) - - return DeleteResponse( - message=f"Thread {thread_id} successfully renamed", - thread_id=thread_id, - ) - - except ValueError as e: - logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=str(e) or "Thread not found or permission denied", - ) - except Exception as e: - logger.error( - f"Error renaming thread {thread_id}: {type(e).__name__}: {str(e)}", - exc_info=True, - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to rename thread: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.delete("/threads/{thread_id}", response_model=DeleteResponse) -@limiter.limit("10/minute") -async def delete_thread( - *, - request: Request, - thread_id: str, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> DeleteResponse: - """ - Delete a chat thread and all its associated data from both LangGraph and database. - """ - try: - await chat_service.delete_thread_for_user( - thread_id=thread_id, - user=currentUser, - session=session, - ) - - logger.info(f"User {currentUser.id} deleted thread {thread_id}") - - return DeleteResponse( - message=f"Thread {thread_id} successfully deleted", - thread_id=thread_id, - ) - - except ValueError as e: - logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=str(e) or "Thread not found or permission denied", - ) - except PermissionError as e: - logger.error(f"Permission denied: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=str(e) or "Permission denied", - ) - except Exception as e: - logger.error( - f"Error deleting thread {thread_id}: {type(e).__name__}: {str(e)}", - exc_info=True, - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete thread: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -# Tool Management Endpoints - - -@router.get("/tools", response_model=ToolListResponse) -@limiter.limit("30/minute") -async def get_all_tools( - *, - request: Request, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> ToolListResponse: - """ - Get all available chatbot tools. - Only accessible to system administrators. - """ - try: - # Check SYSADMIN permission - if currentUser.privilege != UserPrivilege.SYSADMIN: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only system administrators can view tools", - ) - - # Get all tools from service - tools_data = await chat_service.get_all_tools(session=session) - - # Convert to ToolInfo objects - tools = [ToolInfo(**tool) for tool in tools_data] - - logger.info(f"User {currentUser.id} retrieved {len(tools)} tools") - - return ToolListResponse(tools=tools) - - except HTTPException: - raise - except Exception as e: - logger.error( - f"Error retrieving tools: {type(e).__name__}: {str(e)}", exc_info=True - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve tools: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.post("/tools/grant", response_model=GrantToolResponse) -@limiter.limit("10/minute") -async def grant_tool_to_user( - *, - request: Request, - grant_request: GrantToolRequest, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> GrantToolResponse: - """ - Grant a tool to a user. - Only accessible to system administrators. - """ - try: - # Check SYSADMIN permission - if currentUser.privilege != UserPrivilege.SYSADMIN: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only system administrators can grant tools", - ) - - # Grant the tool - await chat_service.grant_tool_to_user( - user_id=grant_request.user_id, - tool_id=grant_request.tool_id, - session=session, - ) - - logger.info( - f"User {currentUser.id} granted tool {grant_request.tool_id} to user {grant_request.user_id}" - ) - - return GrantToolResponse( - message=f"Tool successfully granted to user {grant_request.user_id}", - user_id=grant_request.user_id, - tool_id=grant_request.tool_id, - ) - - except ValueError as e: - logger.error(f"Validation error: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e) or "Invalid request", - ) - except HTTPException: - raise - except Exception as e: - logger.error( - f"Error granting tool: {type(e).__name__}: {str(e)}", exc_info=True - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to grant tool: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.delete("/tools/revoke", response_model=RevokeToolResponse) -@limiter.limit("10/minute") -async def revoke_tool_from_user( - *, - request: Request, - revoke_request: RevokeToolRequest, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> RevokeToolResponse: - """ - Revoke a tool from a user. - Only accessible to system administrators. - """ - try: - # Check SYSADMIN permission - if currentUser.privilege != UserPrivilege.SYSADMIN: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only system administrators can revoke tools", - ) - - # Revoke the tool - await chat_service.revoke_tool_from_user( - user_id=revoke_request.user_id, - tool_id=revoke_request.tool_id, - session=session, - ) - - logger.info( - f"User {currentUser.id} revoked tool {revoke_request.tool_id} from user {revoke_request.user_id}" - ) - - return RevokeToolResponse( - message=f"Tool successfully revoked from user {revoke_request.user_id}", - user_id=revoke_request.user_id, - tool_id=revoke_request.tool_id, - ) - - except ValueError as e: - logger.error(f"Validation error: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e) or "Invalid request", - ) - except HTTPException: - raise - except Exception as e: - logger.error( - f"Error revoking tool: {type(e).__name__}: {str(e)}", exc_info=True - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to revoke tool: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.patch("/tools/{tool_id}", response_model=UpdateToolResponse) -@limiter.limit("10/minute") -async def update_tool( - *, - request: Request, - tool_id: str, - update_request: UpdateToolRequest, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> UpdateToolResponse: - """ - Update a tool's label and/or description. - Only accessible to system administrators. - """ - try: - # Check SYSADMIN permission - if currentUser.privilege != UserPrivilege.SYSADMIN: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only system administrators can update tools", - ) - - # Update the tool - updated_fields = await chat_service.update_tool( - tool_id=tool_id, - label=update_request.label, - description=update_request.description, - session=session, - ) - - logger.info( - f"User {currentUser.id} updated tool {tool_id}, fields: {updated_fields}" - ) - - return UpdateToolResponse( - message="Tool successfully updated", - tool_id=tool_id, - updated_fields=updated_fields, - ) - - except ValueError as e: - logger.error(f"Validation error: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e) or "Invalid request", - ) - except HTTPException: - raise - except Exception as e: - logger.error( - f"Error updating tool: {type(e).__name__}: {str(e)}", exc_info=True - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update tool: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.get("/tools/user/{user_id}", response_model=ToolListResponse) -@limiter.limit("30/minute") -async def get_tools_for_specific_user( - *, - request: Request, - user_id: str, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> ToolListResponse: - """ - Get all tools granted to a specific user. - Only accessible to system administrators. - """ - try: - # Check SYSADMIN permission - if currentUser.privilege != UserPrivilege.SYSADMIN: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only system administrators can view user tools", - ) - - # Get tools for the specified user - tools_data = await chat_service.get_tools_for_user( - user_id=user_id, session=session - ) - - # Convert to ToolInfo objects - tools = [ToolInfo(**tool) for tool in tools_data] - - logger.info( - f"User {currentUser.id} retrieved {len(tools)} tools for user {user_id}" - ) - - return ToolListResponse(tools=tools) - - except HTTPException: - raise - except Exception as e: - logger.error( - f"Error retrieving tools for user {user_id}: {type(e).__name__}: {str(e)}", - exc_info=True, - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve tools for user: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) - - -@router.get("/tools/me", response_model=ToolListResponse) -@limiter.limit("30/minute") -async def get_my_tools( - *, - request: Request, - currentUser: User = Depends(getCurrentUser), - session: AsyncSession = Depends(get_async_db_session), -) -> ToolListResponse: - """ - Get all tools the current user has access to. - """ - try: - # Get tools for the current user - tools_data = await chat_service.get_tools_for_user( - user_id=currentUser.id, session=session - ) - - # Convert to ToolInfo objects - tools = [ToolInfo(**tool) for tool in tools_data] - - logger.info( - f"User {currentUser.id} retrieved {len(tools)} tools for themselves" - ) - - return ToolListResponse(tools=tools) - - except Exception as e: - logger.error( - f"Error retrieving tools for user {currentUser.id}: {type(e).__name__}: {str(e)}", - exc_info=True, - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve your tools: {type(e).__name__}: {str(e) or 'No error message provided'}", - ) diff --git a/modules/workflows/workflowManager.py b/modules/workflows/workflowManager.py index e220fe69..6c227f10 100644 --- a/modules/workflows/workflowManager.py +++ b/modules/workflows/workflowManager.py @@ -1,6 +1,5 @@ from typing import Dict, Any, List, Optional import logging -from datetime import datetime, UTC import uuid import asyncio import json