From 86567f93e8eeba6b78a9a265bafd5cce3c8488fe Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 1 Oct 2025 11:24:51 +0200 Subject: [PATCH 01/29] chore: gracefully handle favicon issue --- modules/routes/routeAdmin.py | 44 ++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/modules/routes/routeAdmin.py b/modules/routes/routeAdmin.py index 6fa23ff6..992ad596 100644 --- a/modules/routes/routeAdmin.py +++ b/modules/routes/routeAdmin.py @@ -21,13 +21,14 @@ os.makedirs(staticFolder, exist_ok=True) logger = logging.getLogger(__name__) router = APIRouter( - prefix="", - tags=["Administration"], - responses={404: {"description": "Not found"}} + prefix="", tags=["Administration"], responses={404: {"description": "Not found"}} ) # Mount static files -router.mount("/static", StaticFiles(directory=str(staticFolder), html=True), name="static") +router.mount( + "/static", StaticFiles(directory=str(staticFolder), html=True), name="static" +) + @router.get("/") @limiter.limit("30/minute") @@ -36,31 +37,40 @@ async def root(request: Request) -> Dict[str, str]: # Validate required configuration values allowedOrigins = APP_CONFIG.get("APP_ALLOWED_ORIGINS") if not allowedOrigins: - raise HTTPException(status_code=500, detail="APP_ALLOWED_ORIGINS configuration is required") - + raise HTTPException( + status_code=500, detail="APP_ALLOWED_ORIGINS configuration is required" + ) + return { "status": "online", "message": "Data Platform API is active", - "allowedOrigins": f"Allowed origins are {allowedOrigins}" + "allowedOrigins": f"Allowed origins are {allowedOrigins}", } + @router.get("/api/environment") @limiter.limit("30/minute") -async def get_environment(request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, str]: +async def get_environment( + request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser) +) -> Dict[str, str]: """Get environment configuration for frontend""" # Validate required configuration values apiBaseUrl = APP_CONFIG.get("APP_API_URL") if not apiBaseUrl: - raise HTTPException(status_code=500, detail="APP_API_URL configuration is required") - + raise HTTPException( + status_code=500, detail="APP_API_URL configuration is required" + ) + environment = APP_CONFIG.get("APP_ENV") if not environment: raise HTTPException(status_code=500, detail="APP_ENV configuration is required") - + instanceLabel = APP_CONFIG.get("APP_ENV_LABEL") if not instanceLabel: - raise HTTPException(status_code=500, detail="APP_ENV_LABEL configuration is required") - + raise HTTPException( + status_code=500, detail="APP_ENV_LABEL configuration is required" + ) + return { "apiBaseUrl": apiBaseUrl, "environment": environment, @@ -68,13 +78,17 @@ async def get_environment(request: Request, currentUser: Dict[str, Any] = Depend # Add other environment variables the frontend might need } + @router.options("/{fullPath:path}") @limiter.limit("60/minute") async def options_route(request: Request, fullPath: str) -> Response: return Response(status_code=200) + @router.get("/favicon.ico") @limiter.limit("30/minute") async def favicon(request: Request) -> FileResponse: - return FileResponse(str(staticFolder / "favicon.ico"), media_type="image/x-icon") - + favicon_path = staticFolder / "favicon.ico" + if not favicon_path.exists(): + raise HTTPException(status_code=404, detail="Favicon not found") + return FileResponse(str(favicon_path), media_type="image/x-icon") From 68d6ab9890bff4ff66bf918d08c95769b12400bd Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 1 Oct 2025 16:00:19 +0200 Subject: [PATCH 02/29] feat: add chatbot dummy router --- app.py | 155 +++++++---- .../chatbotTools/customerTools/__init__.py | 1 + .../chatbotTools/sharedTools/__init__.py | 1 + modules/features/chatBot/utils/permissions.py | 7 + modules/routes/routeChatbot.py | 254 ++++++++++++++++++ 5 files changed, 373 insertions(+), 45 deletions(-) create mode 100644 modules/features/chatBot/chatbotTools/customerTools/__init__.py create mode 100644 modules/features/chatBot/chatbotTools/sharedTools/__init__.py create mode 100644 modules/features/chatBot/utils/permissions.py create mode 100644 modules/routes/routeChatbot.py diff --git a/app.py b/app.py index 30def90e..ed4e7214 100644 --- a/app.py +++ b/app.py @@ -1,10 +1,11 @@ import os + os.environ["NUMEXPR_MAX_THREADS"] = "12" from fastapi import FastAPI, HTTPException, Depends, Body, status, Response from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager - + import logging from logging.handlers import RotatingFileHandler @@ -20,32 +21,36 @@ class DailyRotatingFileHandler(RotatingFileHandler): A rotating file handler that automatically switches to a new file when the date changes. The log file name includes the current date and switches at midnight. """ - - def __init__(self, log_dir, filename_prefix, max_bytes=10485760, backup_count=5, **kwargs): + + def __init__( + self, log_dir, filename_prefix, max_bytes=10485760, backup_count=5, **kwargs + ): self.log_dir = log_dir self.filename_prefix = filename_prefix self.current_date = None self.current_file = None - + # Initialize with today's file self._update_file_if_needed() - + # Call parent constructor with current file - super().__init__(self.current_file, maxBytes=max_bytes, backupCount=backup_count, **kwargs) - + super().__init__( + self.current_file, maxBytes=max_bytes, backupCount=backup_count, **kwargs + ) + def _update_file_if_needed(self): """Update the log file if the date has changed""" today = datetime.now().strftime("%Y%m%d") - + if self.current_date != today: self.current_date = today new_file = os.path.join(self.log_dir, f"{self.filename_prefix}_{today}.log") - + if self.current_file != new_file: self.current_file = new_file return True return False - + def emit(self, record): """Emit a log record, switching files if date has changed""" # Check if we need to switch to a new file @@ -54,16 +59,17 @@ class DailyRotatingFileHandler(RotatingFileHandler): if self.stream: self.stream.close() self.stream = None - + # Update the baseFilename for the parent class self.baseFilename = self.current_file # Reopen the stream if not self.delay: self.stream = self._open() - + # Call parent emit method super().emit(record) + def initLogging(): """Initialize logging with configuration from APP_CONFIG""" # Get log level from config (default to INFO if not found) @@ -76,33 +82,39 @@ def initLogging(): # If relative path, make it relative to the gateway directory gatewayDir = os.path.dirname(os.path.abspath(__file__)) logDir = os.path.join(gatewayDir, logDir) - + # Ensure log directory exists os.makedirs(logDir, exist_ok=True) # Create formatters - using single line format consoleFormatter = logging.Formatter( fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S") + datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"), ) - + # File formatter with more detailed error information but still single line fileFormatter = logging.Formatter( fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s - %(pathname)s:%(lineno)d - %(funcName)s", - datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S") + datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"), ) # Add filter to exclude Chrome DevTools requests class ChromeDevToolsFilter(logging.Filter): def filter(self, record): - return not (isinstance(record.msg, str) and - ('.well-known/appspecific/com.chrome.devtools.json' in record.msg or - 'Request: /index.html' in record.msg)) + return not ( + isinstance(record.msg, str) + and ( + ".well-known/appspecific/com.chrome.devtools.json" in record.msg + or "Request: /index.html" in record.msg + ) + ) # Add filter to exclude all httpcore loggers (including sub-loggers) class HttpcoreStarFilter(logging.Filter): def filter(self, record): - return not (record.name == 'httpcore' or record.name.startswith('httpcore.')) + return not ( + record.name == "httpcore" or record.name.startswith("httpcore.") + ) # Add filter to exclude HTTP debug messages class HTTPDebugFilter(logging.Filter): @@ -110,14 +122,14 @@ def initLogging(): if isinstance(record.msg, str): # Filter out HTTP debug messages http_debug_patterns = [ - 'receive_response_body.started', - 'receive_response_body.complete', - 'response_closed.started', - '_send_single_request', - 'httpcore.http11', - 'httpx._client', - 'HTTP Request', - 'multipart.multipart' + "receive_response_body.started", + "receive_response_body.complete", + "response_closed.started", + "_send_single_request", + "httpcore.http11", + "httpx._client", + "HTTP Request", + "multipart.multipart", ] return not any(pattern in record.msg for pattern in http_debug_patterns) return True @@ -129,8 +141,21 @@ def initLogging(): # Remove only emojis, preserve other Unicode characters like quotes import re import unicodedata + # Remove emoji characters specifically - record.msg = ''.join(char for char in record.msg if unicodedata.category(char) != 'So' or not (0x1F600 <= ord(char) <= 0x1F64F or 0x1F300 <= ord(char) <= 0x1F5FF or 0x1F680 <= ord(char) <= 0x1F6FF or 0x1F1E0 <= ord(char) <= 0x1F1FF or 0x2600 <= ord(char) <= 0x26FF or 0x2700 <= ord(char) <= 0x27BF)) + record.msg = "".join( + char + for char in record.msg + if unicodedata.category(char) != "So" + or not ( + 0x1F600 <= ord(char) <= 0x1F64F + or 0x1F300 <= ord(char) <= 0x1F5FF + or 0x1F680 <= ord(char) <= 0x1F6FF + or 0x1F1E0 <= ord(char) <= 0x1F1FF + or 0x2600 <= ord(char) <= 0x26FF + or 0x2700 <= ord(char) <= 0x27BF + ) + ) return True # Configure handlers based on config @@ -149,14 +174,16 @@ def initLogging(): # Add file handler if enabled if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True): # Create daily application log file with automatic date switching - rotationSize = int(APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760)) # Default: 10MB + rotationSize = int( + APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760) + ) # Default: 10MB backupCount = int(APP_CONFIG.get("APP_LOGGING_BACKUP_COUNT", 5)) - + fileHandler = DailyRotatingFileHandler( log_dir=logDir, filename_prefix="log_app", max_bytes=rotationSize, - backup_count=backupCount + backup_count=backupCount, ) fileHandler.setFormatter(fileFormatter) fileHandler.addFilter(ChromeDevToolsFilter()) @@ -171,11 +198,18 @@ def initLogging(): format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"), handlers=handlers, - force=True # Force reconfiguration of the root logger + force=True, # Force reconfiguration of the root logger ) # Silence noisy third-party libraries - use the same level as the root logger - noisyLoggers = ["httpx", "httpcore", "urllib3", "asyncio", "fastapi.security.oauth2", "msal"] + noisyLoggers = [ + "httpx", + "httpcore", + "urllib3", + "asyncio", + "fastapi.security.oauth2", + "msal", + ] for loggerName in noisyLoggers: logging.getLogger(loggerName).setLevel(logging.WARNING) @@ -183,21 +217,25 @@ def initLogging(): logger = logging.getLogger(__name__) logger.info(f"Logging initialized with level {logLevelName}") logger.info(f"Log directory: {logDir}") - - if APP_CONFIG.get('APP_LOGGING_FILE_ENABLED', True): + + if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True): today = datetime.now().strftime("%Y%m%d") appLogFile = os.path.join(logDir, f"log_app_{today}.log") logger.info(f"Application log file: {appLogFile} (auto-switches daily)") else: logger.info("Application log file: disabled") - - logger.info(f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}") + + logger.info( + f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}" + ) + # Initialize logging initLogging() logger = logging.getLogger(__name__) instanceLabel = APP_CONFIG.get("APP_ENV_LABEL") + # Define lifespan context manager for application startup/shutdown events @asynccontextmanager async def lifespan(app: FastAPI): @@ -210,11 +248,12 @@ async def lifespan(app: FastAPI): # START APP app = FastAPI( - title="PowerOn | Data Platform API", + title="PowerOn | Data Platform API", description=f"Backend API for the Multi-Agent Platform by ValueOn AG ({instanceLabel})", - lifespan=lifespan + lifespan=lifespan, ) + # Parse CORS origins from environment variable def get_allowed_origins(): origins_str = APP_CONFIG.get("APP_ALLOWED_ORIGINS", "http://localhost:8080") @@ -223,73 +262,99 @@ def get_allowed_origins(): logger.info(f"CORS allowed origins: {origins}") return origins + # CORS configuration using environment variables app.add_middleware( CORSMiddleware, - allow_origins= get_allowed_origins(), + allow_origins=get_allowed_origins(), allow_credentials=True, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], expose_headers=["*"], - max_age=86400 # Increased caching for preflight requests + max_age=86400, # Increased caching for preflight requests ) # CSRF protection middleware from modules.security.csrf import CSRFMiddleware -from modules.security.tokenRefreshMiddleware import TokenRefreshMiddleware, ProactiveTokenRefreshMiddleware +from modules.security.tokenRefreshMiddleware import ( + TokenRefreshMiddleware, + ProactiveTokenRefreshMiddleware, +) + app.add_middleware(CSRFMiddleware) # Token refresh middleware (silent refresh for expired OAuth tokens) app.add_middleware(TokenRefreshMiddleware, enabled=True) # Proactive token refresh middleware (refresh tokens before they expire) -app.add_middleware(ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5) +app.add_middleware( + ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5 +) # Run triggered features import modules.features.init # Include all routers from modules.routes.routeAdmin import router as generalRouter + app.include_router(generalRouter) from modules.routes.routeAttributes import router as attributesRouter + app.include_router(attributesRouter) from modules.routes.routeDataMandates import router as mandateRouter + app.include_router(mandateRouter) from modules.routes.routeDataUsers import router as userRouter + app.include_router(userRouter) from modules.routes.routeDataFiles import router as fileRouter + app.include_router(fileRouter) from modules.routes.routeDataNeutralization import router as neutralizationRouter + app.include_router(neutralizationRouter) from modules.routes.routeDataPrompts import router as promptRouter + app.include_router(promptRouter) from modules.routes.routeDataConnections import router as connectionsRouter + app.include_router(connectionsRouter) from modules.routes.routeWorkflows import router as workflowRouter + app.include_router(workflowRouter) from modules.routes.routeChatPlayground import router as chatPlaygroundRouter + app.include_router(chatPlaygroundRouter) from modules.routes.routeSecurityLocal import router as localRouter + app.include_router(localRouter) from modules.routes.routeSecurityMsft import router as msftRouter + app.include_router(msftRouter) from modules.routes.routeSecurityGoogle import router as googleRouter + app.include_router(googleRouter) from modules.routes.routeVoiceGoogle import router as voiceGoogleRouter + app.include_router(voiceGoogleRouter) from modules.routes.routeSecurityAdmin import router as adminSecurityRouter -app.include_router(adminSecurityRouter) \ No newline at end of file + +app.include_router(adminSecurityRouter) + +from modules.routes.routeChatbot import router as chatbotRouter + +app.include_router(chatbotRouter) diff --git a/modules/features/chatBot/chatbotTools/customerTools/__init__.py b/modules/features/chatBot/chatbotTools/customerTools/__init__.py new file mode 100644 index 00000000..52043b31 --- /dev/null +++ b/modules/features/chatBot/chatbotTools/customerTools/__init__.py @@ -0,0 +1 @@ +"""Tools that are shared between multiple customers go here.""" diff --git a/modules/features/chatBot/chatbotTools/sharedTools/__init__.py b/modules/features/chatBot/chatbotTools/sharedTools/__init__.py new file mode 100644 index 00000000..b0b10bb2 --- /dev/null +++ b/modules/features/chatBot/chatbotTools/sharedTools/__init__.py @@ -0,0 +1 @@ +"""Tools that are custom to a specific customer go here.""" diff --git a/modules/features/chatBot/utils/permissions.py b/modules/features/chatBot/utils/permissions.py new file mode 100644 index 00000000..45e306a5 --- /dev/null +++ b/modules/features/chatBot/utils/permissions.py @@ -0,0 +1,7 @@ +# get_allowed_tools + + +# get_allowed_models + + +# get_system_prompt diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py new file mode 100644 index 00000000..ff5fa9f4 --- /dev/null +++ b/modules/routes/routeChatbot.py @@ -0,0 +1,254 @@ +from pydantic import BaseModel, Field +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.requests import Request +from typing import Any, Dict, List, Optional +from datetime import datetime +import logging +import uuid + +from modules.datamodels.datamodelUam import User +from modules.security.auth import getCurrentUser, limiter + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/api/chatbot", + tags=["Chatbot"], + responses={404: {"description": "Not found"}}, +) + +# --- Pydantic models for requests and responses --- + + +class ChatMessageRequest(BaseModel): + """Request model for posting a chat message""" + + thread_id: Optional[str] = Field( + None, description="Thread ID (creates new thread if not provided)" + ) + message: str = Field(..., description="User message content") + + +class MessageItem(BaseModel): + """Individual message in a thread""" + + role: str = Field(..., description="Message role (user or assistant)") + content: str = Field(..., description="Message content") + timestamp: float = Field(..., description="Message timestamp (Unix timestamp)") + + +class ChatMessageResponse(BaseModel): + """Response model for posting a chat message""" + + thread_id: str = Field(..., description="Thread ID") + messages: List[MessageItem] = Field(..., description="All messages in thread") + + +class ThreadSummary(BaseModel): + """Summary of a chat thread for list view""" + + thread_id: str = Field(..., description="Thread ID") + created_at: float = Field(..., description="Thread creation timestamp") + last_message: str = Field(..., description="Last message content") + message_count: int = Field(..., description="Total number of messages") + + +class ThreadListResponse(BaseModel): + """Response model for listing all threads""" + + threads: List[ThreadSummary] = Field(..., description="List of thread summaries") + + +class ThreadDetail(BaseModel): + """Detailed view of a single thread""" + + thread_id: str = Field(..., description="Thread ID") + created_at: float = Field(..., description="Thread creation timestamp") + messages: List[MessageItem] = Field( + ..., description="All messages in chronological order" + ) + + +class DeleteResponse(BaseModel): + """Response model for delete operations""" + + message: str = Field(..., description="Confirmation message") + thread_id: str = Field(..., description="Deleted thread ID") + + +# --- Actual endpoints for chatbot --- + + +@router.post("/message", response_model=ChatMessageResponse) +@limiter.limit("30/minute") +async def post_chat_message( + *, + request: Request, + message_request: ChatMessageRequest, + currentUser: User = Depends(getCurrentUser), +) -> ChatMessageResponse: + """ + Post a message to a chat thread and get assistant response. + Creates a new thread if thread_id is not provided. + + This endpoint will later be connected to LangGraph's checkpointer. + """ + try: + # Generate or use existing thread_id + thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" + + # Get current timestamp + current_time = datetime.now().timestamp() + + # Create dummy message history + # In production, this will fetch from LangGraph's checkpointer + messages = [ + MessageItem( + role="user", content=message_request.message, timestamp=current_time + ), + MessageItem( + role="assistant", + content=f"Echo: {message_request.message} (This is a dummy response. LangGraph integration pending.)", + timestamp=current_time + 0.5, + ), + ] + + logger.info(f"User {currentUser.id} posted message to thread {thread_id}") + + return ChatMessageResponse(thread_id=thread_id, messages=messages) + + except Exception as e: + logger.error(f"Error posting chat message: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to post message: {str(e)}", + ) + + +@router.get("/threads", response_model=ThreadListResponse) +@limiter.limit("30/minute") +async def get_all_threads( + *, request: Request, currentUser: User = Depends(getCurrentUser) +) -> ThreadListResponse: + """ + Get all chat threads for the current user. + + This endpoint will later fetch from LangGraph's PostgreSQL checkpointer. + """ + try: + # Return dummy thread data + # In production, this will query LangGraph's checkpointer database + dummy_threads = [ + ThreadSummary( + thread_id="thread_001", + created_at=datetime.now().timestamp() - 86400, # 1 day ago + last_message="Hello, how can I help you?", + message_count=4, + ), + ThreadSummary( + thread_id="thread_002", + created_at=datetime.now().timestamp() - 3600, # 1 hour ago + last_message="Thank you for your help!", + message_count=8, + ), + ThreadSummary( + thread_id="thread_003", + created_at=datetime.now().timestamp() - 300, # 5 minutes ago + last_message="Can you explain this concept?", + message_count=2, + ), + ] + + logger.info(f"User {currentUser.id} retrieved {len(dummy_threads)} threads") + + return ThreadListResponse(threads=dummy_threads) + + except Exception as e: + logger.error(f"Error retrieving threads: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve threads: {str(e)}", + ) + + +@router.get("/threads/{thread_id}", response_model=ThreadDetail) +@limiter.limit("30/minute") +async def get_thread_by_id( + *, request: Request, thread_id: str, currentUser: User = Depends(getCurrentUser) +) -> ThreadDetail: + """ + Get a specific chat thread with all its messages. + + This endpoint will later fetch from LangGraph's PostgreSQL checkpointer. + """ + try: + # Return dummy thread detail + # In production, this will query LangGraph's checkpointer for the specific thread + current_time = datetime.now().timestamp() + + dummy_messages = [ + MessageItem( + role="user", + content="Hello! I need help with Python.", + timestamp=current_time - 120, + ), + MessageItem( + role="assistant", + content="Hello! I'd be happy to help you with Python. What would you like to know?", + timestamp=current_time - 119, + ), + MessageItem( + role="user", + content="How do I use list comprehensions?", + timestamp=current_time - 60, + ), + MessageItem( + role="assistant", + content="List comprehensions are a concise way to create lists. Here's an example: [x**2 for x in range(10)]", + timestamp=current_time - 59, + ), + ] + + logger.info(f"User {currentUser.id} retrieved thread {thread_id}") + + return ThreadDetail( + thread_id=thread_id, created_at=current_time - 120, messages=dummy_messages + ) + + except Exception as e: + logger.error(f"Error retrieving thread {thread_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve thread: {str(e)}", + ) + + +@router.delete("/threads/{thread_id}", response_model=DeleteResponse) +@limiter.limit("10/minute") +async def delete_thread( + *, request: Request, thread_id: str, currentUser: User = Depends(getCurrentUser) +) -> DeleteResponse: + """ + Delete a chat thread and all its associated data. + + This endpoint will later delete from LangGraph's PostgreSQL checkpointer. + """ + try: + # In production, this will: + # 1. Verify the thread belongs to the current user + # 2. Delete the thread from LangGraph's checkpointer + # 3. Clean up any associated data + + logger.info(f"User {currentUser.id} deleted thread {thread_id}") + + return DeleteResponse( + message=f"Thread {thread_id} successfully deleted (dummy response)", + thread_id=thread_id, + ) + + except Exception as e: + logger.error(f"Error deleting thread {thread_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete thread: {str(e)}", + ) From 98b258ae53ca04abf20c32e055e371a3e65b01c8 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 3 Oct 2025 09:48:32 +0200 Subject: [PATCH 03/29] feat: add langgraph first tool; pydantic v2 --- modules/datamodels/datamodelSecurity.py | 101 ++- .../features/chatBot/chatbotTools/__init__.py | 1 + .../chatbotTools/sharedTools/__init__.py | 8 +- .../sharedTools/toolTavilySearch.py | 55 ++ modules/interfaces/interfaceDbAppObjects.py | 700 ++++++++++-------- modules/shared/attributeUtils.py | 278 ++++--- requirements.txt | 7 +- 7 files changed, 718 insertions(+), 432 deletions(-) create mode 100644 modules/features/chatBot/chatbotTools/__init__.py create mode 100644 modules/features/chatBot/chatbotTools/sharedTools/toolTavilySearch.py diff --git a/modules/datamodels/datamodelSecurity.py b/modules/datamodels/datamodelSecurity.py index ff6a3f6f..fa8b8ed7 100644 --- a/modules/datamodels/datamodelSecurity.py +++ b/modules/datamodels/datamodelSecurity.py @@ -1,7 +1,7 @@ """Security models: Token and AuthEvent.""" from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from modules.shared.attributeUtils import register_model_labels, ModelMixin from modules.shared.timezoneUtils import get_utc_timestamp from .datamodelUam import AuthAuthority @@ -18,21 +18,36 @@ class Token(BaseModel, ModelMixin): id: Optional[str] = None userId: str authority: AuthAuthority - connectionId: Optional[str] = Field(None, description="ID of the connection this token belongs to") + connectionId: Optional[str] = Field( + None, description="ID of the connection this token belongs to" + ) tokenAccess: str tokenType: str = "bearer" - expiresAt: float = Field(description="When the token expires (UTC timestamp in seconds)") + expiresAt: float = Field( + description="When the token expires (UTC timestamp in seconds)" + ) tokenRefresh: Optional[str] = None - createdAt: Optional[float] = Field(None, description="When the token was created (UTC timestamp in seconds)") - status: TokenStatus = Field(default=TokenStatus.ACTIVE, description="Token status: active/revoked") - revokedAt: Optional[float] = Field(None, description="When the token was revoked (UTC timestamp in seconds)") - revokedBy: Optional[str] = Field(None, description="User ID who revoked the token (admin/self)") + createdAt: Optional[float] = Field( + None, description="When the token was created (UTC timestamp in seconds)" + ) + status: TokenStatus = Field( + default=TokenStatus.ACTIVE, description="Token status: active/revoked" + ) + revokedAt: Optional[float] = Field( + None, description="When the token was revoked (UTC timestamp in seconds)" + ) + revokedBy: Optional[str] = Field( + None, description="User ID who revoked the token (admin/self)" + ) reason: Optional[str] = Field(None, description="Optional revocation reason") - sessionId: Optional[str] = Field(None, description="Logical session grouping for logout revocation") - mandateId: Optional[str] = Field(None, description="Mandate ID for tenant scoping of the token") + sessionId: Optional[str] = Field( + None, description="Logical session grouping for logout revocation" + ) + mandateId: Optional[str] = Field( + None, description="Mandate ID for tenant scoping of the token" + ) - class Config: - use_enum_values = True + model_config = ConfigDict(use_enum_values=True) register_model_labels( @@ -59,14 +74,60 @@ register_model_labels( class AuthEvent(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the auth event", frontend_type="text", frontend_readonly=True, frontend_required=False) - userId: str = Field(description="ID of the user this event belongs to", frontend_type="text", frontend_readonly=True, frontend_required=True) - eventType: str = Field(description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')", frontend_type="text", frontend_readonly=True, frontend_required=True) - timestamp: float = Field(default_factory=get_utc_timestamp, description="Unix timestamp when the event occurred", frontend_type="datetime", frontend_readonly=True, frontend_required=True) - ipAddress: Optional[str] = Field(default=None, description="IP address from which the event originated", frontend_type="text", frontend_readonly=True, frontend_required=False) - userAgent: Optional[str] = Field(default=None, description="User agent string from the request", frontend_type="text", frontend_readonly=True, frontend_required=False) - success: bool = Field(default=True, description="Whether the authentication event was successful", frontend_type="boolean", frontend_readonly=True, frontend_required=True) - details: Optional[str] = Field(default=None, description="Additional details about the event", frontend_type="text", frontend_readonly=True, frontend_required=False) + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique ID of the auth event", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + userId: str = Field( + description="ID of the user this event belongs to", + frontend_type="text", + frontend_readonly=True, + frontend_required=True, + ) + eventType: str = Field( + description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')", + frontend_type="text", + frontend_readonly=True, + frontend_required=True, + ) + timestamp: float = Field( + default_factory=get_utc_timestamp, + description="Unix timestamp when the event occurred", + frontend_type="datetime", + frontend_readonly=True, + frontend_required=True, + ) + ipAddress: Optional[str] = Field( + default=None, + description="IP address from which the event originated", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + userAgent: Optional[str] = Field( + default=None, + description="User agent string from the request", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + success: bool = Field( + default=True, + description="Whether the authentication event was successful", + frontend_type="boolean", + frontend_readonly=True, + frontend_required=True, + ) + details: Optional[str] = Field( + default=None, + description="Additional details about the event", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) register_model_labels( @@ -83,5 +144,3 @@ register_model_labels( "details": {"en": "Details", "fr": "Détails"}, }, ) - - diff --git a/modules/features/chatBot/chatbotTools/__init__.py b/modules/features/chatBot/chatbotTools/__init__.py new file mode 100644 index 00000000..2bd4359d --- /dev/null +++ b/modules/features/chatBot/chatbotTools/__init__.py @@ -0,0 +1 @@ +"""Contains all tools available for the chatbot to use.""" diff --git a/modules/features/chatBot/chatbotTools/sharedTools/__init__.py b/modules/features/chatBot/chatbotTools/sharedTools/__init__.py index b0b10bb2..9b0ab5b7 100644 --- a/modules/features/chatBot/chatbotTools/sharedTools/__init__.py +++ b/modules/features/chatBot/chatbotTools/sharedTools/__init__.py @@ -1 +1,7 @@ -"""Tools that are custom to a specific customer go here.""" +"""Shared tools available across all chatbot implementations.""" + +from modules.features.chatBot.chatbotTools.sharedTools.toolTavilySearch import ( + tavily_search, +) + +__all__ = ["tavily_search"] diff --git a/modules/features/chatBot/chatbotTools/sharedTools/toolTavilySearch.py b/modules/features/chatBot/chatbotTools/sharedTools/toolTavilySearch.py new file mode 100644 index 00000000..e3a6c8fc --- /dev/null +++ b/modules/features/chatBot/chatbotTools/sharedTools/toolTavilySearch.py @@ -0,0 +1,55 @@ +"""Tavily Search Tool for LangGraph. + +This tool provides web search capabilities using the Tavily API. +""" + +import logging +from typing import Annotated +from langchain_core.tools import tool +from modules.connectors.connectorAiTavily import ConnectorWeb + +logger = logging.getLogger(__name__) + + +@tool +async def tavily_search( + query: Annotated[str, "The search query to look up on the web"], +) -> str: + """Search the web using Tavily API. + + Use this tool to search for current information, news, or any web content. + The tool returns relevant search results including titles and URLs. + + Args: + query: The search query string + + Returns: + A formatted string containing search results with titles and URLs + """ + try: + # Create connector instance + connector = await ConnectorWeb.create() + + # Perform search with default parameters + results = await connector._search( + query=query, + max_results=5, + search_depth="basic", + include_answer=True, + include_raw_content=False, + ) + + # Format results + if not results: + return f"No results found for query: {query}" + + formatted_results = [f"Search results for '{query}':\n"] + for i, result in enumerate(results, 1): + formatted_results.append(f"{i}. {result.title}") + formatted_results.append(f" URL: {result.url}\n") + + return "\n".join(formatted_results) + + except Exception as e: + logger.error(f"Error in tavily_search tool: {str(e)}") + return f"Error performing search: {str(e)}" diff --git a/modules/interfaces/interfaceDbAppObjects.py b/modules/interfaces/interfaceDbAppObjects.py index 251452f5..36a07484 100644 --- a/modules/interfaces/interfaceDbAppObjects.py +++ b/modules/interfaces/interfaceDbAppObjects.py @@ -18,11 +18,19 @@ from modules.shared.configuration import APP_CONFIG from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp from modules.interfaces.interfaceDbAppAccess import AppAccess from modules.datamodels.datamodelUam import ( - User, Mandate, UserInDB, UserConnection, - AuthAuthority, UserPrivilege, ConnectionStatus, + User, + Mandate, + UserInDB, + UserConnection, + AuthAuthority, + UserPrivilege, + ConnectionStatus, ) from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus -from modules.datamodels.datamodelNeutralizer import DataNeutraliserConfig, DataNeutralizerAttributes +from modules.datamodels.datamodelNeutralizer import ( + DataNeutraliserConfig, + DataNeutralizerAttributes, +) logger = logging.getLogger(__name__) @@ -35,12 +43,13 @@ _rootAppObjects = None # Password-Hashing pwdContext = CryptContext(schemes=["argon2"], deprecated="auto") + class AppObjects: """ Interface to the Gateway system. Manages users and mandates. """ - + def __init__(self, currentUser: Optional[User] = None): """Initializes the Gateway Interface.""" # Initialize variables @@ -48,47 +57,49 @@ class AppObjects: self.userId = currentUser.id if currentUser else None self.mandateId = currentUser.mandateId if currentUser else None self.access = None # Will be set when user context is provided - + # Initialize database self._initializeDatabase() - + # Initialize standard records if needed self._initRecords() - + # Set user context if provided if currentUser: self.setUserContext(currentUser) - + def setUserContext(self, currentUser: User): """Sets the user context for the interface.""" if not currentUser: logger.info("Initializing interface without user context") return - + self.currentUser = currentUser # Store User object directly self.userId = currentUser.id self.mandateId = currentUser.mandateId - + if not self.userId or not self.mandateId: raise ValueError("Invalid user context: id and mandateId are required") - + # Add language settings self.userLanguage = currentUser.language # Default user language - + # Initialize access control with user context - self.access = AppAccess(self.currentUser, self.db) # Convert to dict only when needed - + self.access = AppAccess( + self.currentUser, self.db + ) # Convert to dict only when needed + # Update database context self.db.updateContext(self.userId) - + def __del__(self): """Cleanup method to close database connection.""" - if hasattr(self, 'db') and self.db is not None: + if hasattr(self, "db") and self.db is not None: try: self.db.close() except Exception as e: logger.error(f"Error closing database connection: {e}") - + def _initializeDatabase(self): """Initializes the database connection directly.""" try: @@ -98,7 +109,7 @@ class AppObjects: dbUser = APP_CONFIG.get("DB_APP_USER") dbPassword = APP_CONFIG.get("DB_APP_PASSWORD_SECRET") dbPort = int(APP_CONFIG.get("DB_APP_PORT", 5432)) - + # Create database connector directly self.db = DatabaseConnector( dbHost=dbHost, @@ -106,40 +117,36 @@ class AppObjects: dbUser=dbUser, dbPassword=dbPassword, dbPort=dbPort, - userId=self.userId + userId=self.userId, ) - + # Initialize database system self.db.initDbSystem() - + logger.info(f"Database initialized successfully for user {self.userId}") except Exception as e: logger.error(f"Failed to initialize database: {str(e)}") raise - + def _initRecords(self): """Initialize standard records if they don't exist.""" self._initRootMandate() self._initAdminUser() self._initEventUser() - + def _initRootMandate(self): """Creates the Root mandate if it doesn't exist.""" existingMandateId = self.getInitialId(Mandate) mandates = self.db.getRecordset(Mandate) if existingMandateId is None or not mandates: logger.info("Creating Root mandate") - rootMandate = Mandate( - name="Root", - language="en", - enabled=True - ) + rootMandate = Mandate(name="Root", language="en", enabled=True) createdMandate = self.db.recordCreate(Mandate, rootMandate) logger.info(f"Root mandate created with ID {createdMandate['id']}") - + # Update mandate context - self.mandateId = createdMandate['id'] - + self.mandateId = createdMandate["id"] + def _initAdminUser(self): """Creates the Admin user if it doesn't exist.""" existingUserId = self.getInitialId(UserInDB) @@ -155,12 +162,14 @@ class AppObjects: language="en", privilege=UserPrivilege.SYSADMIN, authenticationAuthority="local", # Using lowercase value directly - hashedPassword=self._getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET")), - connections=[] + hashedPassword=self._getPasswordHash( + APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET") + ), + connections=[], ) createdUser = self.db.recordCreate(UserInDB, adminUser) logger.info(f"Admin user created with ID {createdUser['id']}") - + # Update user context self.currentUser = createdUser self.userId = createdUser.get("id") @@ -168,7 +177,9 @@ class AppObjects: def _initEventUser(self): """Creates the Event user if it doesn't exist.""" # Check if event user already exists - existingUsers = self.db.getRecordset(UserInDB, recordFilter={"username": "event"}) + existingUsers = self.db.getRecordset( + UserInDB, recordFilter={"username": "event"} + ) if not existingUsers: logger.info("Creating Event user") eventUser = UserInDB( @@ -180,44 +191,48 @@ class AppObjects: language="en", privilege=UserPrivilege.SYSADMIN, authenticationAuthority="local", # Using lowercase value directly - hashedPassword=self._getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET")), - connections=[] + hashedPassword=self._getPasswordHash( + APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET") + ), + connections=[], ) createdUser = self.db.recordCreate(UserInDB, eventUser) logger.info(f"Event user created with ID {createdUser['id']}") - def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _uam( + self, model_class: type, recordset: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: """ Unified user access management function that filters data based on user privileges and adds access control attributes. - + Args: model_class: Pydantic model class for the table recordset: Recordset to filter based on access rules - + Returns: Filtered recordset with access control attributes """ # First apply access control filteredRecords = self.access.uam(model_class, recordset) - + # Then filter out database-specific fields cleanedRecords = [] for record in filteredRecords: # Create a new dict with only non-database fields - cleanedRecord = {k: v for k, v in record.items() if not k.startswith('_')} + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} cleanedRecords.append(cleanedRecord) - + return cleanedRecords - + def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool: """ Checks if the current user can modify (create/update/delete) records in a table. - + Args: model_class: Pydantic model class for the table recordId: Optional record ID for specific record check - + Returns: Boolean indicating permission """ @@ -230,22 +245,22 @@ class AppObjects: def _getPasswordHash(self, password: str) -> str: """Creates a hash for a password.""" return pwdContext.hash(password) - + def _verifyPassword(self, plainPassword: str, hashedPassword: str) -> bool: """Checks if the password matches the hash.""" return pwdContext.verify(plainPassword, hashedPassword) - + # User methods - + def getUsersByMandate(self, mandateId: str) -> List[User]: """Returns users for a specific mandate if user has access.""" # Get users for this mandate users = self.db.getRecordset(UserInDB, recordFilter={"mandateId": mandateId}) filteredUsers = self._uam(UserInDB, users) - + # Convert to User models return [User.from_dict(user) for user in filteredUsers] - + def getUserByUsername(self, username: str) -> Optional[User]: """Returns a user by username.""" try: @@ -253,19 +268,19 @@ class AppObjects: users = self.db.getRecordset(UserInDB) if not users: return None - + # Find user by username for user_dict in users: if user_dict.get("username") == username: return User.from_dict(user_dict) - + logger.info(f"No user found with username {username}") return None - + except Exception as e: logger.error(f"Error getting user by username: {str(e)}") return None - + def getUser(self, userId: str) -> Optional[User]: """Returns a user by ID if user has access.""" try: @@ -273,7 +288,7 @@ class AppObjects: users = self.db.getRecordset(UserInDB) if not users: return None - + # Find user by ID for user_dict in users: if user_dict.get("id") == userId: @@ -282,9 +297,9 @@ class AppObjects: if filteredUsers: return User.from_dict(filteredUsers[0]) return None - + return None - + except Exception as e: logger.error(f"Error getting user by ID: {str(e)}") return None @@ -292,42 +307,50 @@ class AppObjects: def authenticateLocalUser(self, username: str, password: str) -> Optional[User]: """Authenticates a user by username and password using local authentication.""" # Clear the users table from cache and reload it - + # Get user by username user = self.getUserByUsername(username) - + if not user: raise ValueError("User not found") - + # Check if the user is enabled if not user.enabled: raise ValueError("User is disabled") - + # Verify that the user has local authentication enabled if user.authenticationAuthority != AuthAuthority.LOCAL: raise ValueError("User does not have local authentication enabled") - + # Get the full user record with password hash for verification userRecord = self.db.getRecordset(UserInDB, recordFilter={"id": user.id})[0] if not userRecord.get("hashedPassword"): raise ValueError("User has no password set") - + if not self._verifyPassword(password, userRecord["hashedPassword"]): raise ValueError("Invalid password") - + return user - def createUser(self, username: str, password: str = None, email: str = None, - fullName: str = None, language: str = "en", enabled: bool = True, - privilege: UserPrivilege = UserPrivilege.USER, - authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL, - externalId: str = None, externalUsername: str = None, - externalEmail: str = None) -> User: + def createUser( + self, + username: str, + password: str = None, + email: str = None, + fullName: str = None, + language: str = "en", + enabled: bool = True, + privilege: UserPrivilege = UserPrivilege.USER, + authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL, + externalId: str = None, + externalUsername: str = None, + externalEmail: str = None, + ) -> User: """Create a new user with optional external connection""" try: # Ensure username is a string username = str(username).strip() - + # Validate password for local authentication if authenticationAuthority == AuthAuthority.LOCAL: if not password: @@ -336,7 +359,7 @@ class AppObjects: raise ValueError("Password must be a string") if not password.strip(): raise ValueError("Password cannot be empty") - + # Create user data using UserInDB model userData = UserInDB( username=username, @@ -348,15 +371,14 @@ class AppObjects: privilege=privilege, authenticationAuthority=authenticationAuthority, hashedPassword=self._getPasswordHash(password) if password else None, - connections=[] + connections=[], ) - + # Create user record createdRecord = self.db.recordCreate(UserInDB, userData) if not createdRecord or not createdRecord.get("id"): raise ValueError("Failed to create user record") - - + # Add external connection if provided if externalId and externalUsername: self.addUserConnection( @@ -364,18 +386,20 @@ class AppObjects: authenticationAuthority, externalId, externalUsername, - externalEmail + externalEmail, ) - + # Get created user using the returned ID - createdUser = self.db.getRecordset(UserInDB, recordFilter={"id": createdRecord["id"]}) + createdUser = self.db.getRecordset( + UserInDB, recordFilter={"id": createdRecord["id"]} + ) if not createdUser or len(createdUser) == 0: raise ValueError("Failed to retrieve created user") - + # Clear cache to ensure fresh data (already done above) - + return User.from_dict(createdUser[0]) - + except ValueError as e: logger.error(f"Error creating user: {str(e)}") raise @@ -390,23 +414,22 @@ class AppObjects: user = self.getUser(userId) if not user: raise ValueError(f"User {userId} not found") - + # Update user data using model updatedData = user.to_dict() updatedData.update(updateData) updatedUser = User.from_dict(updatedData) - + # Update user record self.db.recordModify(UserInDB, userId, updatedUser) - - + # Get updated user updatedUser = self.getUser(userId) if not updatedUser: raise ValueError("Failed to retrieve updated user") - + return updatedUser - + except Exception as e: logger.error(f"Error updating user: {str(e)}") raise ValueError(f"Failed to update user: {str(e)}") @@ -414,34 +437,33 @@ class AppObjects: def disableUser(self, userId: str) -> User: """Disables a user if current user has permission.""" return self.updateUser(userId, {"enabled": False}) - + def enableUser(self, userId: str) -> User: """Enables a user if current user has permission.""" return self.updateUser(userId, {"enabled": True}) - + def _deleteUserReferencedData(self, userId: str) -> None: """Deletes all data associated with a user.""" try: - - # Delete user auth events events = self.db.getRecordset(AuthEvent, recordFilter={"userId": userId}) for event in events: self.db.recordDelete(AuthEvent, event["id"]) - + # Delete user tokens tokens = self.db.getRecordset(Token, recordFilter={"userId": userId}) for token in tokens: self.db.recordDelete(Token, token["id"]) - - + # Delete user connections - connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId}) + connections = self.db.getRecordset( + UserConnection, recordFilter={"userId": userId} + ) for conn in connections: self.db.recordDelete(UserConnection, conn["id"]) - + logger.info(f"All referenced data for user {userId} has been deleted") - + except Exception as e: logger.error(f"Error deleting referenced data for user {userId}: {str(e)}") raise @@ -453,22 +475,21 @@ class AppObjects: user = self.getUser(userId) if not user: raise ValueError(f"User {userId} not found") - + if not self._canModify(UserInDB, userId): raise PermissionError(f"No permission to delete user {userId}") - + # Delete all referenced data first self._deleteUserReferencedData(userId) - + # Delete user record success = self.db.recordDelete(UserInDB, userId) if not success: raise ValueError(f"Failed to delete user {userId}") - - + logger.info(f"User {userId} successfully deleted") return True - + except Exception as e: logger.error(f"Error deleting user: {str(e)}") raise ValueError(f"Failed to delete user: {str(e)}") @@ -479,7 +500,7 @@ class AppObjects: initialUserId = self.getInitialId(UserInDB) if not initialUserId: return None - + users = self.db.getRecordset(UserInDB, recordFilter={"id": initialUserId}) return users[0] if users else None except Exception as e: @@ -491,33 +512,24 @@ class AppObjects: try: username = checkData.get("username") authenticationAuthority = checkData.get("authenticationAuthority", "local") - + if not username: - return { - "available": False, - "message": "Username is required" - } - + return {"available": False, "message": "Username is required"} + # Get user by username user = self.getUserByUsername(username) - + # Check if user exists (User model instance) if user is not None: - return { - "available": False, - "message": "Username is already taken" - } - - return { - "available": True, - "message": "Username is available" - } - + return {"available": False, "message": "Username is already taken"} + + return {"available": True, "message": "Username is available"} + except Exception as e: logger.error(f"Error checking username availability: {str(e)}") return { "available": False, - "message": f"Error checking username availability: {str(e)}" + "message": f"Error checking username availability: {str(e)}", } # Connection methods @@ -526,8 +538,10 @@ class AppObjects: """Returns all connections for a user.""" try: # Get connections for this user - connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId}) - + connections = self.db.getRecordset( + UserConnection, recordFilter={"userId": userId} + ) + # Convert to UserConnection objects result = [] for conn_dict in connections: @@ -543,24 +557,32 @@ class AppObjects: status=conn_dict.get("status", "pending"), connectedAt=conn_dict.get("connectedAt"), lastChecked=conn_dict.get("lastChecked"), - expiresAt=conn_dict.get("expiresAt") + expiresAt=conn_dict.get("expiresAt"), ) result.append(connection) except Exception as e: - logger.error(f"Error converting connection dict to object: {str(e)}") + logger.error( + f"Error converting connection dict to object: {str(e)}" + ) continue return result - + except Exception as e: logger.error(f"Error getting user connections: {str(e)}") return [] - def addUserConnection(self, userId: str, authority: AuthAuthority, externalId: str, - externalUsername: str, externalEmail: Optional[str] = None, - status: ConnectionStatus = ConnectionStatus.PENDING) -> UserConnection: + def addUserConnection( + self, + userId: str, + authority: AuthAuthority, + externalId: str, + externalUsername: str, + externalEmail: Optional[str] = None, + status: ConnectionStatus = ConnectionStatus.PENDING, + ) -> UserConnection: """ Adds a new connection for a user. - + Args: userId: The ID of the user authority: The authentication authority (e.g., MSFT, GOOGLE) @@ -568,7 +590,7 @@ class AppObjects: externalUsername: The username from the authority externalEmail: Optional email from the authority status: The connection status (defaults to PENDING) - + Returns: The created UserConnection object """ @@ -577,7 +599,7 @@ class AppObjects: user = self.getUser(userId) if not user: raise ValueError(f"User not found: {userId}") - + # Create new connection with all required fields connection = UserConnection( id=str(uuid.uuid4()), @@ -589,15 +611,14 @@ class AppObjects: status=status, connectedAt=get_utc_timestamp(), lastChecked=get_utc_timestamp(), - expiresAt=None # Optional field, set to None by default + expiresAt=None, # Optional field, set to None by default ) - + # Save to connections table self.db.recordCreate(UserConnection, connection) - - + return connection - + except Exception as e: logger.error(f"Error adding user connection: {str(e)}") raise ValueError(f"Failed to add user connection: {str(e)}") @@ -606,93 +627,88 @@ class AppObjects: """Remove a connection to an external service""" try: # Get connection - connections = self.db.getRecordset(UserConnection, recordFilter={ - "id": connectionId - }) - + connections = self.db.getRecordset( + UserConnection, recordFilter={"id": connectionId} + ) + if not connections: raise ValueError(f"Connection {connectionId} not found") - + # Delete connection self.db.recordDelete(UserConnection, connectionId) - - + except Exception as e: logger.error(f"Error removing user connection: {str(e)}") raise ValueError(f"Failed to remove user connection: {str(e)}") # Mandate methods - + def getAllMandates(self) -> List[Mandate]: """Returns all mandates based on user access level.""" allMandates = self.db.getRecordset(Mandate) filteredMandates = self._uam(Mandate, allMandates) return [Mandate.from_dict(mandate) for mandate in filteredMandates] - + def getMandate(self, mandateId: str) -> Optional[Mandate]: """Returns a mandate by ID if user has access.""" mandates = self.db.getRecordset(Mandate, recordFilter={"id": mandateId}) if not mandates: return None - + filteredMandates = self._uam(Mandate, mandates) if not filteredMandates: return None - + return Mandate.from_dict(filteredMandates[0]) - + def createMandate(self, name: str, language: str = "en") -> Mandate: """Creates a new mandate if user has permission.""" if not self._canModify(Mandate): raise PermissionError("No permission to create mandates") - + # Create mandate data using model - mandateData = Mandate( - name=name, - language=language - ) - + mandateData = Mandate(name=name, language=language) + # Create mandate record createdRecord = self.db.recordCreate(Mandate, mandateData) if not createdRecord or not createdRecord.get("id"): raise ValueError("Failed to create mandate record") - - + return Mandate.from_dict(createdRecord) - + def updateMandate(self, mandateId: str, updateData: Dict[str, Any]) -> Mandate: """Updates a mandate if user has access.""" try: # First check if user has permission to modify mandates if not self._canModify(Mandate, mandateId): raise PermissionError(f"No permission to update mandate {mandateId}") - + # Get mandate with access control mandate = self.getMandate(mandateId) if not mandate: raise ValueError(f"Mandate {mandateId} not found") - + # Update mandate data using model updatedData = mandate.to_dict() updatedData.update(updateData) updatedMandate = Mandate.from_dict(updatedData) - + # Update mandate record self.db.recordModify(Mandate, mandateId, updatedMandate) - + # Clear cache to ensure fresh data - + # Get updated mandate updatedMandate = self.getMandate(mandateId) if not updatedMandate: raise ValueError("Failed to retrieve updated mandate") - + return updatedMandate - + except Exception as e: logger.error(f"Error updating mandate: {str(e)}") raise ValueError(f"Failed to update mandate: {str(e)}") - + def deleteMandate(self, mandateId: str) -> bool: """Deletes a mandate if user has access.""" try: @@ -700,22 +716,24 @@ class AppObjects: mandate = self.getMandate(mandateId) if not mandate: return False - + if not self._canModify(Mandate, mandateId): raise PermissionError(f"No permission to delete mandate {mandateId}") - + # Check if mandate has users users = self.getUsersByMandate(mandateId) if users: - raise ValueError(f"Cannot delete mandate {mandateId} with existing users") - + raise ValueError( + f"Cannot delete mandate {mandateId} with existing users" + ) + # Delete mandate success = self.db.recordDelete(Mandate, mandateId) - + # Clear cache to ensure fresh data - + return success - + except Exception as e: logger.error(f"Error deleting mandate: {str(e)}") raise ValueError(f"Failed to delete mandate: {str(e)}") @@ -727,51 +745,64 @@ class AppObjects: try: # Validate that this is NOT a connection token if token.connectionId: - raise ValueError("Access tokens cannot have connectionId - use saveConnectionToken instead") - + raise ValueError( + "Access tokens cannot have connectionId - use saveConnectionToken instead" + ) + # Validate user context if not self.currentUser or not self.currentUser.id: raise ValueError("No valid user context available for token storage") - + # Set the user ID and mandate ID token.userId = self.currentUser.id - + # Ensure token has required fields if not token.id: token.id = str(uuid.uuid4()) if not token.createdAt: token.createdAt = get_utc_timestamp() - + # If replace_existing is True, delete old access tokens for this user and authority first if replace_existing: try: - old_tokens = self.db.getRecordset(Token, recordFilter={ - "userId": self.currentUser.id, - "authority": token.authority, - "connectionId": None # Ensure we only delete access tokens - }) + old_tokens = self.db.getRecordset( + Token, + recordFilter={ + "userId": self.currentUser.id, + "authority": token.authority, + "connectionId": None, # Ensure we only delete access tokens + }, + ) deleted_count = 0 for old_token in old_tokens: - if old_token["id"] != token.id: # Don't delete the new token if it already exists + if ( + old_token["id"] != token.id + ): # Don't delete the new token if it already exists self.db.recordDelete(Token, old_token["id"]) deleted_count += 1 - + if deleted_count > 0: - logger.info(f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}") - + logger.info( + f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}" + ) + except Exception as e: - logger.warning(f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}") + logger.warning( + f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}" + ) # Continue with saving the new token even if deletion fails - + # Convert to dict and ensure all fields are properly set - token_dict = token.dict() + token_dict = token.model_dump() + # Ensure userId is set to current user + # Convert to dict and ensure all fields are properly set + token_dict = token.model_dump() # Ensure userId is set to current user token_dict["userId"] = self.currentUser.id - + # Save to database self.db.recordCreate(Token, token_dict) - - + except Exception as e: logger.error(f"Error saving access token: {str(e)}") raise @@ -781,49 +812,56 @@ class AppObjects: try: # Validate that this IS a connection token if not token.connectionId: - raise ValueError("Connection tokens must have connectionId - use saveAccessToken instead") - + raise ValueError( + "Connection tokens must have connectionId - use saveAccessToken instead" + ) + # Validate user context if not self.currentUser or not self.currentUser.id: raise ValueError("No valid user context available for token storage") - + # Set the user ID for the connection token token.userId = self.currentUser.id - + # Ensure token has required fields if not token.id: token.id = str(uuid.uuid4()) if not token.createdAt: token.createdAt = get_utc_timestamp() - + # If replace_existing is True, delete old tokens for this connectionId first if replace_existing: try: - old_tokens = self.db.getRecordset(Token, recordFilter={ - "connectionId": token.connectionId - }) + old_tokens = self.db.getRecordset( + Token, recordFilter={"connectionId": token.connectionId} + ) deleted_count = 0 for old_token in old_tokens: - if old_token["id"] != token.id: # Don't delete the new token if it already exists + if ( + old_token["id"] != token.id + ): # Don't delete the new token if it already exists self.db.recordDelete(Token, old_token["id"]) deleted_count += 1 - + if deleted_count > 0: - logger.info(f"Replaced {deleted_count} old tokens for connectionId {token.connectionId}") - + logger.info( + f"Replaced {deleted_count} old tokens for connectionId {token.connectionId}" + ) + except Exception as e: - logger.warning(f"Failed to delete old tokens for connectionId {token.connectionId}: {str(e)}") + logger.warning( + f"Failed to delete old tokens for connectionId {token.connectionId}: {str(e)}" + ) # Continue with saving the new token even if deletion fails - + # Convert to dict and ensure all fields are properly set - token_dict = token.dict() + token_dict = token.model_dump() # Ensure userId is set to current user token_dict["userId"] = self.currentUser.id - + # Save to database self.db.recordCreate(Token, token_dict) - - + except Exception as e: logger.error(f"Error saving connection token: {str(e)}") raise @@ -834,37 +872,49 @@ class AppObjects: # Validate connectionId if not connectionId: raise ValueError("connectionId is required for getConnectionToken") - + # Get token for this specific connection # Query for specific connection - tokens = self.db.getRecordset(Token, recordFilter={ - "connectionId": connectionId - }) - - + tokens = self.db.getRecordset( + Token, recordFilter={"connectionId": connectionId} + ) + if not tokens: - logger.warning(f"No connection token found for connectionId: {connectionId}") + logger.warning( + f"No connection token found for connectionId: {connectionId}" + ) return None - + # Sort by expiration date and get the latest (most recent expiration) tokens.sort(key=lambda x: x.get("expiresAt", 0), reverse=True) latest_token = Token(**tokens[0]) - + # No auto-refresh here. Callers should use a higher-level service to refresh when needed. - + return latest_token - + except Exception as e: - logger.error(f"Error getting connection token for connectionId {connectionId}: {str(e)}") + logger.error( + f"Error getting connection token for connectionId {connectionId}: {str(e)}" + ) return None - def findActiveTokenById(self, tokenId: str, userId: str, authority: AuthAuthority, sessionId: str = None, mandateId: str = None) -> Optional[Token]: + def findActiveTokenById( + self, + tokenId: str, + userId: str, + authority: AuthAuthority, + sessionId: str = None, + mandateId: str = None, + ) -> Optional[Token]: """Find an active access token by its id (jti) with optional session/tenant scoping.""" try: recordFilter = { "id": tokenId, "userId": userId, - "authority": authority.value if hasattr(authority, 'value') else str(authority), + "authority": authority.value + if hasattr(authority, "value") + else str(authority), "status": TokenStatus.ACTIVE, } if sessionId is not None: @@ -892,7 +942,7 @@ class AppObjects: "status": TokenStatus.REVOKED, "revokedAt": get_utc_timestamp(), "revokedBy": revokedBy, - "reason": reason or "revoked" + "reason": reason or "revoked", } self.db.recordModify(Token, tokenId, tokenUpdate) return True @@ -900,30 +950,53 @@ class AppObjects: logger.error(f"Error revoking token {tokenId}: {str(e)}") return False - def revokeTokensBySessionId(self, sessionId: str, userId: str, authority: AuthAuthority, revokedBy: str, reason: str = None) -> int: + def revokeTokensBySessionId( + self, + sessionId: str, + userId: str, + authority: AuthAuthority, + revokedBy: str, + reason: str = None, + ) -> int: """Revoke all tokens of a session for a user/authority.""" try: - tokens = self.db.getRecordset(Token, recordFilter={ - "userId": userId, - "authority": authority.value if hasattr(authority, 'value') else str(authority), - "sessionId": sessionId, - "status": TokenStatus.ACTIVE - }) + tokens = self.db.getRecordset( + Token, + recordFilter={ + "userId": userId, + "authority": authority.value + if hasattr(authority, "value") + else str(authority), + "sessionId": sessionId, + "status": TokenStatus.ACTIVE, + }, + ) count = 0 for t in tokens: - self.db.recordModify(Token, t["id"], { - "status": TokenStatus.REVOKED, - "revokedAt": get_utc_timestamp(), - "revokedBy": revokedBy, - "reason": reason or "session logout" - }) + self.db.recordModify( + Token, + t["id"], + { + "status": TokenStatus.REVOKED, + "revokedAt": get_utc_timestamp(), + "revokedBy": revokedBy, + "reason": reason or "session logout", + }, + ) count += 1 return count except Exception as e: logger.error(f"Error revoking tokens for session {sessionId}: {str(e)}") return 0 - def revokeTokensByUser(self, userId: str, authority: AuthAuthority = None, mandateId: str = None, revokedBy: str = None, reason: str = None) -> int: + def revokeTokensByUser( + self, + userId: str, + authority: AuthAuthority = None, + mandateId: str = None, + revokedBy: str = None, + reason: str = None, + ) -> int: """Revoke all active tokens for a user, optionally filtered by authority/mandate.""" try: # Fetch all active tokens for user (optionally filtered by authority) @@ -932,16 +1005,22 @@ class AppObjects: "status": TokenStatus.ACTIVE, } if authority is not None: - recordFilter["authority"] = authority.value if hasattr(authority, 'value') else str(authority) + recordFilter["authority"] = ( + authority.value if hasattr(authority, "value") else str(authority) + ) tokens = self.db.getRecordset(Token, recordFilter=recordFilter) count = 0 for t in tokens: - self.db.recordModify(Token, t["id"], { - "status": TokenStatus.REVOKED, - "revokedAt": get_utc_timestamp(), - "revokedBy": revokedBy, - "reason": reason or "admin revoke" - }) + self.db.recordModify( + Token, + t["id"], + { + "status": TokenStatus.REVOKED, + "revokedAt": get_utc_timestamp(), + "revokedBy": revokedBy, + "reason": reason or "admin revoke", + }, + ) count += 1 return count except Exception as e: @@ -953,22 +1032,25 @@ class AppObjects: try: current_time = get_utc_timestamp() cleaned_count = 0 - + # Get all tokens all_tokens = self.db.getRecordset(Token, recordFilter={}) - + for token_data in all_tokens: - if token_data.get("expiresAt") and token_data.get("expiresAt") < current_time: + if ( + token_data.get("expiresAt") + and token_data.get("expiresAt") < current_time + ): # Token is expired, delete it self.db.recordDelete(Token, token_data["id"]) cleaned_count += 1 - + # Clear cache to ensure fresh data if cleaned_count > 0: logger.info(f"Cleaned up {cleaned_count} expired tokens") - + return cleaned_count - + except Exception as e: logger.error(f"Error cleaning up expired tokens: {str(e)}") return 0 @@ -981,79 +1063,92 @@ class AppObjects: self.userId = None self.mandateId = None self.access = None - + # Clear database context - if hasattr(self, 'db'): + if hasattr(self, "db"): self.db.updateContext("") - + logger.info("User logged out successfully") - + except Exception as e: logger.error(f"Error during logout: {str(e)}") raise # Neutralization methods - + def getNeutralizationConfig(self) -> Optional[DataNeutraliserConfig]: """Get the data neutralization configuration for the current user's mandate""" try: - configs = self.db.getRecordset(DataNeutraliserConfig, recordFilter={"mandateId": self.mandateId}) + configs = self.db.getRecordset( + DataNeutraliserConfig, recordFilter={"mandateId": self.mandateId} + ) if not configs: return None - + # Apply access control filtered_configs = self._uam(DataNeutraliserConfig, configs) if not filtered_configs: return None - + return DataNeutraliserConfig.from_dict(filtered_configs[0]) - + except Exception as e: logger.error(f"Error getting neutralization config: {str(e)}") return None - def createOrUpdateNeutralizationConfig(self, config_data: Dict[str, Any]) -> DataNeutraliserConfig: + def createOrUpdateNeutralizationConfig( + self, config_data: Dict[str, Any] + ) -> DataNeutraliserConfig: """Create or update the data neutralization configuration""" try: # Check if config already exists existing_config = self.getNeutralizationConfig() - + if existing_config: # Update existing config update_data = existing_config.to_dict() update_data.update(config_data) update_data["updatedAt"] = get_utc_timestamp() - + updated_config = DataNeutraliserConfig.from_dict(update_data) - self.db.recordModify(DataNeutraliserConfig, existing_config.id, updated_config) - + self.db.recordModify( + DataNeutraliserConfig, existing_config.id, updated_config + ) + return updated_config else: # Create new config config_data["mandateId"] = self.mandateId config_data["userId"] = self.userId - + new_config = DataNeutraliserConfig.from_dict(config_data) created_record = self.db.recordCreate(DataNeutraliserConfig, new_config) - + return DataNeutraliserConfig.from_dict(created_record) - + except Exception as e: logger.error(f"Error creating/updating neutralization config: {str(e)}") raise ValueError(f"Failed to create/update neutralization config: {str(e)}") - def getNeutralizationAttributes(self, file_id: Optional[str] = None) -> List[DataNeutralizerAttributes]: + def getNeutralizationAttributes( + self, file_id: Optional[str] = None + ) -> List[DataNeutralizerAttributes]: """Get neutralization attributes, optionally filtered by file ID""" try: filter_dict = {"mandateId": self.mandateId} if file_id: filter_dict["fileId"] = file_id - - attributes = self.db.getRecordset(DataNeutralizerAttributes, recordFilter=filter_dict) + + attributes = self.db.getRecordset( + DataNeutralizerAttributes, recordFilter=filter_dict + ) filtered_attributes = self._uam(DataNeutralizerAttributes, attributes) - - return [DataNeutralizerAttributes.from_dict(attr) for attr in filtered_attributes] - + + return [ + DataNeutralizerAttributes.from_dict(attr) + for attr in filtered_attributes + ] + except Exception as e: logger.error(f"Error getting neutralization attributes: {str(e)}") return [] @@ -1061,23 +1156,27 @@ class AppObjects: def deleteNeutralizationAttributes(self, file_id: str) -> bool: """Delete all neutralization attributes for a specific file""" try: - attributes = self.db.getRecordset(DataNeutralizerAttributes, recordFilter={ - "mandateId": self.mandateId, - "fileId": file_id - }) - + attributes = self.db.getRecordset( + DataNeutralizerAttributes, + recordFilter={"mandateId": self.mandateId, "fileId": file_id}, + ) + for attribute in attributes: self.db.recordDelete(DataNeutralizerAttributes, attribute["id"]) - - logger.info(f"Deleted {len(attributes)} neutralization attributes for file {file_id}") + + logger.info( + f"Deleted {len(attributes)} neutralization attributes for file {file_id}" + ) return True - + except Exception as e: logger.error(f"Error deleting neutralization attributes: {str(e)}") return False + # Public Methods + def getInterface(currentUser: User) -> AppObjects: """ Returns a AppObjects instance for the current user. @@ -1085,46 +1184,49 @@ def getInterface(currentUser: User) -> AppObjects: """ if not currentUser: raise ValueError("Invalid user context: user is required") - + # Create context key contextKey = f"{currentUser.mandateId}_{currentUser.id}" - + # Create new instance if not exists if contextKey not in _gatewayInterfaces: _gatewayInterfaces[contextKey] = AppObjects(currentUser) - + return _gatewayInterfaces[contextKey] + def getRootInterface() -> AppObjects: """ Returns a AppObjects instance with root privileges. This is used for initial setup and user creation. """ global _rootAppObjects - + if _rootAppObjects is None: try: # Create a temporary interface without user context to get root user tempInterface = AppObjects() - + # Get the initial user directly initialUserId = tempInterface.getInitialId(UserInDB) if not initialUserId: raise ValueError("No initial user ID found in database") - - users = tempInterface.db.getRecordset(UserInDB, recordFilter={"id": initialUserId}) + + users = tempInterface.db.getRecordset( + UserInDB, recordFilter={"id": initialUserId} + ) if not users: raise ValueError("Initial user not found in database") - + # Convert to User model user_data = users[0] - rootUser = User.parse_obj(user_data) - + rootUser = User.model_validate(user_data) + # Create root interface with the root user _rootAppObjects = AppObjects(rootUser) - + except Exception as e: logger.error(f"Error getting root user: {str(e)}") raise ValueError(f"Failed to get root user: {str(e)}") - + return _rootAppObjects diff --git a/modules/shared/attributeUtils.py b/modules/shared/attributeUtils.py index 551a557d..cb67bef9 100644 --- a/modules/shared/attributeUtils.py +++ b/modules/shared/attributeUtils.py @@ -2,52 +2,55 @@ Shared utilities for model attributes and labels. """ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Dict, Any, List, Type, Optional, Union import inspect import importlib import os from datetime import datetime + class ModelMixin: """Mixin class that provides serialization methods for Pydantic models.""" - + def to_dict(self) -> Dict[str, Any]: """ Convert a Pydantic model to a dictionary. Handles both Pydantic v1 and v2. All timestamp fields remain as float values. - + Returns: Dict[str, Any]: Dictionary representation of the model """ # Get the raw dictionary - if hasattr(self, 'model_dump'): + if hasattr(self, "model_dump"): data: Dict[str, Any] = self.model_dump() # Pydantic v2 else: data: Dict[str, Any] = self.dict() # Pydantic v1 - + # All fields (including timestamps) remain in their original format # No conversions needed - timestamps are already float - + return data - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'ModelMixin': + def from_dict(cls, data: Dict[str, Any]) -> "ModelMixin": """ Create a Pydantic model instance from a dictionary. - + Args: data: Dictionary containing the model data - + Returns: ModelMixin: New instance of the model class """ return cls(**data) + # Define the AttributeDefinition class here instead of importing it class AttributeDefinition(BaseModel, ModelMixin): """Definition of a model attribute with its metadata.""" + name: str type: str label: str @@ -64,41 +67,47 @@ class AttributeDefinition(BaseModel, ModelMixin): order: int = 0 placeholder: Optional[str] = None + # Global registry for model labels MODEL_LABELS: Dict[str, Dict[str, Dict[str, str]]] = {} + def to_dict(model: BaseModel) -> Dict[str, Any]: """ Convert a Pydantic model to a dictionary. Handles both Pydantic v1 and v2. - + Args: model: The Pydantic model instance to convert - + Returns: Dict[str, Any]: Dictionary representation of the model """ - if hasattr(model, 'model_dump'): + if hasattr(model, "model_dump"): return model.model_dump() # Pydantic v2 return model.dict() # Pydantic v1 + def from_dict(model_class: Type[BaseModel], data: Dict[str, Any]) -> BaseModel: """ Create a Pydantic model instance from a dictionary. - + Args: model_class: The Pydantic model class to instantiate data: Dictionary containing the model data - + Returns: BaseModel: New instance of the model class """ return model_class(**data) -def register_model_labels(model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]]): + +def register_model_labels( + model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]] +): """ Register labels for a model's attributes and the model itself. - + Args: model_name: Name of the model class model_label: Dictionary mapping language codes to model labels @@ -106,38 +115,37 @@ def register_model_labels(model_name: str, model_label: Dict[str, str], labels: labels: Dictionary mapping attribute names to their translations e.g. {"name": {"en": "Name", "fr": "Nom"}} """ - MODEL_LABELS[model_name] = { - "model": model_label, - "attributes": labels - } + MODEL_LABELS[model_name] = {"model": model_label, "attributes": labels} + def get_model_labels(model_name: str, language: str = "en") -> Dict[str, str]: """ Get labels for a model's attributes in the specified language. - + Args: model_name: Name of the model class language: Language code (default: "en") - + Returns: Dictionary mapping attribute names to their labels in the specified language """ model_data = MODEL_LABELS.get(model_name, {}) attribute_labels = model_data.get("attributes", {}) - + return { attr: translations.get(language, translations.get("en", attr)) for attr, translations in attribute_labels.items() } + def get_model_label(model_name: str, language: str = "en") -> str: """ Get the label for a model in the specified language. - + Args: model_name: Name of the model class language: Language code (default: "en") - + Returns: Model label in the specified language, or model name if no label exists """ @@ -145,156 +153,205 @@ def get_model_label(model_name: str, language: str = "en") -> str: model_label = model_data.get("model", {}) return model_label.get(language, model_label.get("en", model_name)) -def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguage: str = "en") -> Dict[str, Any]: + +def getModelAttributeDefinitions( + modelClass: Type[BaseModel] = None, userLanguage: str = "en" +) -> Dict[str, Any]: """ Get attribute definitions for a model class. - + Args: modelClass: The model class to get attributes for userLanguage: Language code for translations (default: "en") - + Returns: Dictionary containing model label and attribute definitions """ if not modelClass: return {} - + attributes = [] model_name = modelClass.__name__ labels = get_model_labels(model_name, userLanguage) model_label = get_model_label(model_name, userLanguage) - + # Handle both Pydantic v1 and v2 - if hasattr(modelClass, 'model_fields'): # Pydantic v2 + if hasattr(modelClass, "model_fields"): # Pydantic v2 fields = modelClass.model_fields for name, field in fields.items(): # Extract frontend metadata from field info - field_info = field.field_info if hasattr(field, 'field_info') else None + field_info = field.field_info if hasattr(field, "field_info") else None # Check both direct attributes and extra field for frontend metadata frontend_type = None frontend_readonly = False frontend_required = field.is_required() frontend_options = None - + if field_info: # Try direct attributes first - frontend_type = getattr(field_info, 'frontend_type', None) - frontend_readonly = getattr(field_info, 'frontend_readonly', False) - frontend_required = getattr(field_info, 'frontend_required', frontend_required) - frontend_options = getattr(field_info, 'frontend_options', None) - + frontend_type = getattr(field_info, "frontend_type", None) + frontend_readonly = getattr(field_info, "frontend_readonly", False) + frontend_required = getattr( + field_info, "frontend_required", frontend_required + ) + frontend_options = getattr(field_info, "frontend_options", None) + # If not found, check extra field - if hasattr(field_info, 'extra') and field_info.extra: + if hasattr(field_info, "extra") and field_info.extra: if frontend_type is None: - frontend_type = field_info.extra.get('frontend_type') + frontend_type = field_info.extra.get("frontend_type") if not frontend_readonly: - frontend_readonly = field_info.extra.get('frontend_readonly', False) - if frontend_required == field.is_required(): # Only override if we didn't get it from direct attribute - frontend_required = field_info.extra.get('frontend_required', frontend_required) + frontend_readonly = field_info.extra.get( + "frontend_readonly", False + ) + if ( + frontend_required == field.is_required() + ): # Only override if we didn't get it from direct attribute + frontend_required = field_info.extra.get( + "frontend_required", frontend_required + ) if frontend_options is None: - frontend_options = field_info.extra.get('frontend_options') - + frontend_options = field_info.extra.get("frontend_options") + # Use frontend type if available, otherwise fall back to Python type - field_type = frontend_type if frontend_type else (field.annotation.__name__ if hasattr(field.annotation, "__name__") else str(field.annotation)) - - attributes.append({ - "name": name, - "type": field_type, - "required": frontend_required, - "description": field.description if hasattr(field, "description") else "", - "label": labels.get(name, name), - "placeholder": f"Please enter {labels.get(name, name)}", - "editable": not frontend_readonly, - "visible": True, - "order": len(attributes), - "readonly": frontend_readonly, - "options": frontend_options - }) + field_type = ( + frontend_type + if frontend_type + else ( + field.annotation.__name__ + if hasattr(field.annotation, "__name__") + else str(field.annotation) + ) + ) + + attributes.append( + { + "name": name, + "type": field_type, + "required": frontend_required, + "description": field.description + if hasattr(field, "description") + else "", + "label": labels.get(name, name), + "placeholder": f"Please enter {labels.get(name, name)}", + "editable": not frontend_readonly, + "visible": True, + "order": len(attributes), + "readonly": frontend_readonly, + "options": frontend_options, + } + ) else: # Pydantic v1 fields = modelClass.__fields__ for name, field in fields.items(): # Extract frontend metadata from field info - field_info = field.field_info if hasattr(field, 'field_info') else None + field_info = field.field_info if hasattr(field, "field_info") else None # Check both direct attributes and extra field for frontend metadata frontend_type = None frontend_readonly = False frontend_required = field.required frontend_options = None - + if field_info: # Try direct attributes first - frontend_type = getattr(field_info, 'frontend_type', None) - frontend_readonly = getattr(field_info, 'frontend_readonly', False) - frontend_required = getattr(field_info, 'frontend_required', frontend_required) - frontend_options = getattr(field_info, 'frontend_options', None) - + frontend_type = getattr(field_info, "frontend_type", None) + frontend_readonly = getattr(field_info, "frontend_readonly", False) + frontend_required = getattr( + field_info, "frontend_required", frontend_required + ) + frontend_options = getattr(field_info, "frontend_options", None) + # If not found, check extra field - if hasattr(field_info, 'extra') and field_info.extra: + if hasattr(field_info, "extra") and field_info.extra: if frontend_type is None: - frontend_type = field_info.extra.get('frontend_type') + frontend_type = field_info.extra.get("frontend_type") if not frontend_readonly: - frontend_readonly = field_info.extra.get('frontend_readonly', False) - if frontend_required == field.required: # Only override if we didn't get it from direct attribute - frontend_required = field_info.extra.get('frontend_required', frontend_required) + frontend_readonly = field_info.extra.get( + "frontend_readonly", False + ) + if ( + frontend_required == field.required + ): # Only override if we didn't get it from direct attribute + frontend_required = field_info.extra.get( + "frontend_required", frontend_required + ) if frontend_options is None: - frontend_options = field_info.extra.get('frontend_options') - + frontend_options = field_info.extra.get("frontend_options") + # Use frontend type if available, otherwise fall back to Python type - field_type = frontend_type if frontend_type else (field.type_.__name__ if hasattr(field.type_, "__name__") else str(field.type_)) - - attributes.append({ - "name": name, - "type": field_type, - "required": frontend_required, - "description": field.field_info.description if hasattr(field.field_info, "description") else "", - "label": labels.get(name, name), - "placeholder": f"Please enter {labels.get(name, name)}", - "editable": not frontend_readonly, - "visible": True, - "order": len(attributes), - "readonly": frontend_readonly, - "options": frontend_options - }) - - return { - "model": model_label, - "attributes": attributes - } + field_type = ( + frontend_type + if frontend_type + else ( + field.type_.__name__ + if hasattr(field.type_, "__name__") + else str(field.type_) + ) + ) + + attributes.append( + { + "name": name, + "type": field_type, + "required": frontend_required, + "description": field.field_info.description + if hasattr(field.field_info, "description") + else "", + "label": labels.get(name, name), + "placeholder": f"Please enter {labels.get(name, name)}", + "editable": not frontend_readonly, + "visible": True, + "order": len(attributes), + "readonly": frontend_readonly, + "options": frontend_options, + } + ) + + return {"model": model_label, "attributes": attributes} + def getModelClasses() -> Dict[str, Type[BaseModel]]: """ Dynamically get all model classes from all model modules. - + Returns: Dict[str, Type[BaseModel]]: Dictionary of model class names to their classes """ modelClasses = {} - + # Get the interfaces directory path - interfaces_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'interfaces') - + interfaces_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "interfaces" + ) + # Find all model files for fileName in os.listdir(interfaces_dir): - if fileName.endswith('Model.py'): + if fileName.endswith("Model.py"): # Convert fileName to module name (e.g., gatewayModel.py -> gatewayModel) module_name = fileName[:-3] - + # Import the module dynamically - module = importlib.import_module(f'modules.interfaces.{module_name}') - + module = importlib.import_module(f"modules.interfaces.{module_name}") + # Get all classes from the module for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj != BaseModel: + if ( + inspect.isclass(obj) + and issubclass(obj, BaseModel) + and obj != BaseModel + ): modelClasses[name] = obj - + return modelClasses + class AttributeResponse(BaseModel): """Response model for entity attributes""" + attributes: List[AttributeDefinition] - - class Config: - schema_extra = { + + model_config = ConfigDict( + json_schema_extra={ "example": { "attributes": [ { @@ -305,8 +362,9 @@ class AttributeResponse(BaseModel): "placeholder": "Please enter username", "editable": True, "visible": True, - "order": 0 + "order": 0, } ] } - } \ No newline at end of file + } + ) diff --git a/requirements.txt b/requirements.txt index f5a1a2dc..28c8bb99 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ websockets==12.0 uvicorn==0.23.2 python-multipart==0.0.6 httpx==0.25.0 -pydantic==1.10.13 # Ältere Version ohne Rust-Abhängigkeit +pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility email-validator==2.0.0 # Required by Pydantic for email validation slowapi==0.1.8 # For rate limiting @@ -108,3 +108,8 @@ xyzservices>=2021.09.1 # PostgreSQL connector dependencies psycopg2-binary==2.9.9 + +## LangChain & LangGraph +langchain==0.3.27 +langgraph==0.6.8 +langchain-core==0.3.77 From 63dba85b7a30927d194f24c13c57f77f5cd79af2 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 3 Oct 2025 10:29:37 +0200 Subject: [PATCH 04/29] feat: mock althaus db query tool --- .../customerTools/toolQueryAlthausDatabase.py | 208 ++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py diff --git a/modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py b/modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py new file mode 100644 index 00000000..72c15f15 --- /dev/null +++ b/modules/features/chatBot/chatbotTools/customerTools/toolQueryAlthausDatabase.py @@ -0,0 +1,208 @@ +"""Althaus Database Query Tool for LangGraph. + +This tool provides database query capabilities for the Althaus database +via an external REST API. Only SELECT queries are allowed. +""" + +import logging +import asyncio +import re +from typing import Annotated +from langchain_core.tools import tool + +logger = logging.getLogger(__name__) + + +async def _mock_api_call(*, sql_query: str) -> dict: + """Mock the external REST API call to Althaus database. + + Args: + sql_query: The SQL SELECT query to execute + + Returns: + A dictionary containing the query results with columns and rows + """ + # Simulate network delay + await asyncio.sleep(0.5) + + # Mock response data based on common query patterns + if "users" in sql_query.lower(): + return { + "columns": ["id", "username", "email", "created_at"], + "rows": [ + [1, "john_doe", "john@example.com", "2024-01-15"], + [2, "jane_smith", "jane@example.com", "2024-02-20"], + [3, "bob_wilson", "bob@example.com", "2024-03-10"], + ], + "row_count": 3, + } + elif "products" in sql_query.lower(): + return { + "columns": ["product_id", "name", "price", "stock"], + "rows": [ + [101, "Widget A", 29.99, 150], + [102, "Widget B", 39.99, 75], + [103, "Widget C", 19.99, 200], + ], + "row_count": 3, + } + elif "orders" in sql_query.lower(): + return { + "columns": ["order_id", "customer_id", "total", "status"], + "rows": [ + [5001, 1, 129.99, "completed"], + [5002, 2, 89.50, "pending"], + [5003, 1, 199.99, "shipped"], + ], + "row_count": 3, + } + else: + # Generic response for other queries + return { + "columns": ["id", "value", "description"], + "rows": [ + [1, "Sample 1", "First sample entry"], + [2, "Sample 2", "Second sample entry"], + ], + "row_count": 2, + } + + +def _validate_select_query(*, sql_query: str) -> tuple[bool, str]: + """Validate that the query is a SELECT statement only. + + Args: + sql_query: The SQL query to validate + + Returns: + A tuple of (is_valid, error_message) + """ + # Remove leading/trailing whitespace and convert to lowercase for checking + normalized_query = sql_query.strip().lower() + + # Check if query starts with SELECT + if not normalized_query.startswith("select"): + return False, "Query must be a SELECT statement" + + # Check for dangerous keywords that should not be in a SELECT query + dangerous_keywords = [ + "insert", + "update", + "delete", + "drop", + "create", + "alter", + "truncate", + "grant", + "revoke", + "exec", + "execute", + ] + + for keyword in dangerous_keywords: + # Use word boundary to match whole words only + if re.search(rf"\b{keyword}\b", normalized_query): + return False, f"Query contains forbidden keyword: {keyword.upper()}" + + return True, "" + + +def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str: + """Format query results into a readable string. + + Args: + columns: List of column names + rows: List of row data + row_count: Total number of rows + + Returns: + Formatted string representation of the results + """ + if row_count == 0: + return "Query executed successfully but returned no results." + + # Calculate column widths + col_widths = [len(str(col)) for col in columns] + for row in rows: + for i, cell in enumerate(row): + col_widths[i] = max(col_widths[i], len(str(cell))) + + # Build header + header_parts = [] + for col, width in zip(columns, col_widths): + header_parts.append(str(col).ljust(width)) + header = " | ".join(header_parts) + separator = "-" * len(header) + + # Build rows + row_lines = [] + for row in rows: + row_parts = [] + for cell, width in zip(row, col_widths): + row_parts.append(str(cell).ljust(width)) + row_lines.append(" | ".join(row_parts)) + + # Combine all parts + result_parts = [ + f"Query returned {row_count} row(s):\n", + header, + separator, + "\n".join(row_lines), + ] + + return "\n".join(result_parts) + + +@tool +async def query_althaus_database( + sql_query: Annotated[ + str, "The SQL SELECT query to execute against the Althaus database" + ], +) -> str: + """Execute a SELECT query against the Althaus database via REST API. + + Use this tool to query data from the Althaus database. Only SELECT statements + are allowed for security reasons. The query will be forwarded to an external + REST API and the results will be returned in a formatted table. + + Args: + sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1") + + Returns: + A formatted string containing the query results with columns and rows + """ + try: + # Validate the query + is_valid, error_msg = _validate_select_query(sql_query=sql_query) + if not is_valid: + logger.warning(f"Invalid query attempt: {sql_query[:100]}...") + return f"Error: {error_msg}" + + logger.info(f"Executing Althaus database query: {sql_query[:100]}...") + + # Mock the external REST API call + # In production, this would be replaced with actual REST API call: + # response = await httpx.AsyncClient().post( + # "https://api.althaus.example.com/query", + # json={"query": sql_query}, + # headers={"Authorization": f"Bearer {api_key}"} + # ) + # result = response.json() + + result = await _mock_api_call(sql_query=sql_query) + + # Format and return results + formatted_output = _format_results( + columns=result["columns"], + rows=result["rows"], + row_count=result["row_count"], + ) + + logger.info( + f"Query completed successfully, returned {result['row_count']} row(s)" + ) + return formatted_output + + except Exception as e: + logger.error(f"Error in query_althaus_database tool: {str(e)}") + return f"Error executing query: {str(e)}" From 2158b9074876abfc759b4839fc365a41e58cf2d3 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 3 Oct 2025 13:03:22 +0200 Subject: [PATCH 05/29] feat: add tools registry --- modules/datamodels/datamodelWorkflow.py | 155 ++++++--- .../features/chatBot/utils/toolRegistry.py | 305 ++++++++++++++++++ .../chatBot/utils/test_toolRegistry.py | 198 ++++++++++++ 3 files changed, 618 insertions(+), 40 deletions(-) create mode 100644 modules/features/chatBot/utils/toolRegistry.py create mode 100644 tests/features/chatBot/utils/test_toolRegistry.py diff --git a/modules/datamodels/datamodelWorkflow.py b/modules/datamodels/datamodelWorkflow.py index 0ff2dcca..686144c3 100644 --- a/modules/datamodels/datamodelWorkflow.py +++ b/modules/datamodels/datamodelWorkflow.py @@ -1,5 +1,6 @@ """Workflow-related base datamodels and step/task structures.""" +from enum import Enum from typing import List, Dict, Any, Optional from pydantic import BaseModel, Field from modules.shared.attributeUtils import register_model_labels, ModelMixin @@ -7,9 +8,12 @@ from modules.shared.attributeUtils import register_model_labels, ModelMixin class ActionDocument(BaseModel, ModelMixin): """Clear document structure for action results""" + documentName: str = Field(description="Name of the document") documentData: Any = Field(description="Content/data of the document") mimeType: str = Field(description="MIME type of the document") + + register_model_labels( "ActionDocument", {"en": "Action Document", "fr": "Document d'action"}, @@ -31,16 +35,25 @@ class ActionResult(BaseModel, ModelMixin): success: bool = Field(description="Whether execution succeeded") error: Optional[str] = Field(None, description="Error message if failed") - documents: List[ActionDocument] = Field(default_factory=list, description="Document outputs") - resultLabel: Optional[str] = Field(None, description="Label for document routing (set by action handler, not by action methods)") + documents: List[ActionDocument] = Field( + default_factory=list, description="Document outputs" + ) + resultLabel: Optional[str] = Field( + None, + description="Label for document routing (set by action handler, not by action methods)", + ) @classmethod def isSuccess(cls, documents: List[ActionDocument] = None) -> "ActionResult": return cls(success=True, documents=documents or []) @classmethod - def isFailure(cls, error: str, documents: List[ActionDocument] = None) -> "ActionResult": + def isFailure( + cls, error: str, documents: List[ActionDocument] = None + ) -> "ActionResult": return cls(success=False, documents=documents or [], error=error) + + register_model_labels( "ActionResult", {"en": "Action Result", "fr": "Résultat de l'action"}, @@ -55,7 +68,9 @@ register_model_labels( class ActionSelection(BaseModel, ModelMixin): method: str = Field(description="Method to execute (e.g., web, document, ai)") - name: str = Field(description="Action name within the method (e.g., search, extract)") + name: str = Field( + description="Action name within the method (e.g., search, extract)" + ) register_model_labels( @@ -69,7 +84,9 @@ register_model_labels( class ActionParameters(BaseModel, ModelMixin): - parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to execute the selected action") + parameters: Dict[str, Any] = Field( + default_factory=dict, description="Parameters to execute the selected action" + ) register_model_labels( @@ -102,8 +119,12 @@ class Observation(BaseModel, ModelMixin): success: bool = Field(description="Action execution success flag") resultLabel: str = Field(description="Deterministic label for produced documents") documentsCount: int = Field(description="Number of produced documents") - previews: List[ObservationPreview] = Field(default_factory=list, description="Compact previews of outputs") - notes: List[str] = Field(default_factory=list, description="Short notes or key facts") + previews: List[ObservationPreview] = Field( + default_factory=list, description="Compact previews of outputs" + ) + notes: List[str] = Field( + default_factory=list, description="Short notes or key facts" + ) register_model_labels( @@ -119,7 +140,9 @@ register_model_labels( ) -class TaskStatus(str): +class TaskStatus(str, Enum): + """Task status enumeration.""" + PENDING = "pending" RUNNING = "running" COMPLETED = "completed" @@ -142,7 +165,9 @@ register_model_labels( class DocumentExchange(BaseModel, ModelMixin): documentsLabel: str = Field(description="Label for the set of documents") - documents: List[str] = Field(default_factory=list, description="List of document references") + documents: List[str] = Field( + default_factory=list, description="List of document references" + ) register_model_labels( @@ -159,16 +184,28 @@ class TaskAction(BaseModel, ModelMixin): id: str = Field(..., description="Action ID") execMethod: str = Field(..., description="Method to execute") execAction: str = Field(..., description="Action to perform") - execParameters: Dict[str, Any] = Field(default_factory=dict, description="Action parameters") - execResultLabel: Optional[str] = Field(None, description="Label for the set of result documents") - expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(None, description="Expected document formats (optional)") - userMessage: Optional[str] = Field(None, description="User-friendly message in user's language") + execParameters: Dict[str, Any] = Field( + default_factory=dict, description="Action parameters" + ) + execResultLabel: Optional[str] = Field( + None, description="Label for the set of result documents" + ) + expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field( + None, description="Expected document formats (optional)" + ) + userMessage: Optional[str] = Field( + None, description="User-friendly message in user's language" + ) status: TaskStatus = Field(default=TaskStatus.PENDING, description="Action status") error: Optional[str] = Field(None, description="Error message if action failed") retryCount: int = Field(default=0, description="Number of retries attempted") retryMax: int = Field(default=3, description="Maximum number of retries") - processingTime: Optional[float] = Field(None, description="Processing time in seconds") - timestamp: float = Field(..., description="When the action was executed (UTC timestamp in seconds)") + processingTime: Optional[float] = Field( + None, description="Processing time in seconds" + ) + timestamp: float = Field( + ..., description="When the action was executed (UTC timestamp in seconds)" + ) result: Optional[str] = Field(None, description="Result of the action") @@ -181,7 +218,10 @@ register_model_labels( "execAction": {"en": "Action", "fr": "Action"}, "execParameters": {"en": "Parameters", "fr": "Paramètres"}, "execResultLabel": {"en": "Result Label", "fr": "Label du résultat"}, - "expectedDocumentFormats": {"en": "Expected Document Formats", "fr": "Formats de documents attendus"}, + "expectedDocumentFormats": { + "en": "Expected Document Formats", + "fr": "Formats de documents attendus", + }, "userMessage": {"en": "User Message", "fr": "Message utilisateur"}, "status": {"en": "Status", "fr": "Statut"}, "error": {"en": "Error", "fr": "Erreur"}, @@ -221,16 +261,30 @@ class TaskItem(BaseModel, ModelMixin): userInput: str = Field(..., description="User input that triggered the task") status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status") error: Optional[str] = Field(None, description="Error message if task failed") - startedAt: Optional[float] = Field(None, description="When the task started (UTC timestamp in seconds)") - finishedAt: Optional[float] = Field(None, description="When the task finished (UTC timestamp in seconds)") - actionList: List[TaskAction] = Field(default_factory=list, description="List of actions to execute") + startedAt: Optional[float] = Field( + None, description="When the task started (UTC timestamp in seconds)" + ) + finishedAt: Optional[float] = Field( + None, description="When the task finished (UTC timestamp in seconds)" + ) + actionList: List[TaskAction] = Field( + default_factory=list, description="List of actions to execute" + ) retryCount: int = Field(default=0, description="Number of retries attempted") retryMax: int = Field(default=3, description="Maximum number of retries") - rollbackOnFailure: bool = Field(default=True, description="Whether to rollback on failure") - dependencies: List[str] = Field(default_factory=list, description="List of task IDs this task depends on") + rollbackOnFailure: bool = Field( + default=True, description="Whether to rollback on failure" + ) + dependencies: List[str] = Field( + default_factory=list, description="List of task IDs this task depends on" + ) feedback: Optional[str] = Field(None, description="Task feedback message") - processingTime: Optional[float] = Field(None, description="Total processing time in seconds") - resultLabels: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Map of result labels to their values") + processingTime: Optional[float] = Field( + None, description="Total processing time in seconds" + ) + resultLabels: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="Map of result labels to their values" + ) register_model_labels( @@ -258,7 +312,9 @@ class TaskStep(BaseModel, ModelMixin): dependencies: Optional[list[str]] = Field(default_factory=list) success_criteria: Optional[list[str]] = Field(default_factory=list) estimated_complexity: Optional[str] = None - userMessage: Optional[str] = Field(None, description="User-friendly message in user's language") + userMessage: Optional[str] = Field( + None, description="User-friendly message in user's language" + ) register_model_labels( @@ -269,7 +325,10 @@ register_model_labels( "objective": {"en": "Objective", "fr": "Objectif"}, "dependencies": {"en": "Dependencies", "fr": "Dépendances"}, "success_criteria": {"en": "Success Criteria", "fr": "Critères de succès"}, - "estimated_complexity": {"en": "Estimated Complexity", "fr": "Complexité estimée"}, + "estimated_complexity": { + "en": "Estimated Complexity", + "fr": "Complexité estimée", + }, "userMessage": {"en": "User Message", "fr": "Message utilisateur"}, }, ) @@ -278,15 +337,31 @@ register_model_labels( class TaskHandover(BaseModel, ModelMixin): taskId: str = Field(description="Target task ID") sourceTask: Optional[str] = Field(None, description="Source task ID") - inputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Available input documents") - outputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Produced output documents") + inputDocuments: List[DocumentExchange] = Field( + default_factory=list, description="Available input documents" + ) + outputDocuments: List[DocumentExchange] = Field( + default_factory=list, description="Produced output documents" + ) context: Dict[str, Any] = Field(default_factory=dict, description="Task context") - previousResults: List[str] = Field(default_factory=list, description="Previous result summaries") - improvements: List[str] = Field(default_factory=list, description="Improvement suggestions") - workflowSummary: Optional[str] = Field(None, description="Summarized workflow context") - messageHistory: List[str] = Field(default_factory=list, description="Key message summaries") - timestamp: float = Field(..., description="When the handover was created (UTC timestamp in seconds)") - handoverType: str = Field(default="task", description="Type of handover: task, phase, or workflow") + previousResults: List[str] = Field( + default_factory=list, description="Previous result summaries" + ) + improvements: List[str] = Field( + default_factory=list, description="Improvement suggestions" + ) + workflowSummary: Optional[str] = Field( + None, description="Summarized workflow context" + ) + messageHistory: List[str] = Field( + default_factory=list, description="Key message summaries" + ) + timestamp: float = Field( + ..., description="When the handover was created (UTC timestamp in seconds)" + ) + handoverType: str = Field( + default="task", description="Type of handover: task, phase, or workflow" + ) register_model_labels( @@ -310,7 +385,7 @@ register_model_labels( class TaskContext(BaseModel, ModelMixin): task_step: TaskStep - workflow: Optional['ChatWorkflow'] = None + workflow: Optional["ChatWorkflow"] = None workflow_id: Optional[str] = None available_documents: Optional[str] = "No documents available" available_connections: Optional[list[str]] = Field(default_factory=list) @@ -358,7 +433,9 @@ class ReviewResult(BaseModel, ModelMixin): met_criteria: Optional[list[str]] = Field(default_factory=list) unmet_criteria: Optional[list[str]] = Field(default_factory=list) confidence: Optional[float] = 0.5 - userMessage: Optional[str] = Field(None, description="User-friendly message in user's language") + userMessage: Optional[str] = Field( + None, description="User-friendly message in user's language" + ) register_model_labels( @@ -381,7 +458,9 @@ register_model_labels( class TaskPlan(BaseModel, ModelMixin): overview: str tasks: list[TaskStep] - userMessage: Optional[str] = Field(None, description="Overall user-friendly message for the task plan") + userMessage: Optional[str] = Field( + None, description="Overall user-friendly message for the task plan" + ) register_model_labels( @@ -393,7 +472,3 @@ register_model_labels( "userMessage": {"en": "User Message", "fr": "Message utilisateur"}, }, ) - - - - diff --git a/modules/features/chatBot/utils/toolRegistry.py b/modules/features/chatBot/utils/toolRegistry.py new file mode 100644 index 00000000..5f5d14d6 --- /dev/null +++ b/modules/features/chatBot/utils/toolRegistry.py @@ -0,0 +1,305 @@ +"""Tool registry for auto-discovering and managing chatbot tools. + +This module provides a central registry that automatically discovers all tools +in the chatbotTools directory structure and provides methods to query them. +The registry is built in-memory at startup and does not require a database. +""" + +import importlib +import inspect +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional + +from langchain_core.tools import BaseTool + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolMetadata: + """Metadata about a discovered chatbot tool. + + Attributes: + tool_id: Unique identifier (e.g., 'shared.tavily_search') + name: Function name of the tool + category: Category of the tool ('shared' or 'customer') + description: Tool description from docstring + tool_instance: The actual LangChain tool instance + module_path: Full Python module path + """ + + tool_id: str + name: str + category: str + description: str + tool_instance: BaseTool + module_path: str + + def __str__(self) -> str: + """Return a pretty-printed string representation for logging.""" + return ( + f"ToolMetadata(\n" + f" tool_id='{self.tool_id}',\n" + f" name='{self.name}',\n" + f" category='{self.category}',\n" + f" description='{self.description}',\n" + f" module_path='{self.module_path}'\n" + f")" + ) + + +class ToolRegistry: + """Central registry for all chatbot tools. + + This class discovers and catalogs all tools decorated with @tool in the + chatbotTools directory structure. Tools are automatically discovered at + initialization by scanning the filesystem. + + The registry provides methods to query tools by ID, category, or get all tools. + """ + + def __init__(self) -> None: + """Initialize an empty tool registry.""" + self._tools: Dict[str, ToolMetadata] = {} + self._initialized: bool = False + + def initialize(self) -> None: + """Discover and register all tools from the chatbotTools directory. + + This method scans both sharedTools and customerTools directories, + imports all tool*.py modules, and extracts functions decorated with @tool. + + This method is idempotent - calling it multiple times has no effect + after the first initialization. + """ + if self._initialized: + logger.debug("Tool registry already initialized, skipping") + return + + logger.info("Initializing tool registry...") + + # Get base path to chatbotTools directory + base_path = Path(__file__).parent.parent / "chatbotTools" + + if not base_path.exists(): + logger.warning(f"chatbotTools directory not found at {base_path}") + self._initialized = True + return + + # Discover tools in each category + self._discover_category( + category_path=base_path / "sharedTools", category="shared" + ) + self._discover_category( + category_path=base_path / "customerTools", category="customer" + ) + + self._initialized = True + logger.info(f"Tool registry initialized with {len(self._tools)} tools") + + def _discover_category(self, *, category_path: Path, category: str) -> None: + """Discover all tools in a specific category directory. + + Args: + category_path: Path to the category directory (sharedTools or customerTools) + category: Category name ('shared' or 'customer') + """ + if not category_path.exists(): + logger.warning(f"Category directory not found: {category_path}") + return + + logger.debug(f"Discovering tools in category: {category}") + + # Find all tool*.py files (excluding __init__.py) + tool_files = [ + f for f in category_path.glob("tool*.py") if f.name != "__init__.py" + ] + + for tool_file in tool_files: + self._import_and_register_tools( + tool_file=tool_file, category=category, category_path=category_path + ) + + logger.debug(f"Discovered {len(tool_files)} tool files in {category}") + + def _import_and_register_tools( + self, *, tool_file: Path, category: str, category_path: Path + ) -> None: + """Import a tool module and register all discovered tools. + + Args: + tool_file: Path to the tool Python file + category: Category name ('shared' or 'customer') + category_path: Path to the category directory + """ + # Construct module name + module_name = ( + f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}" + ) + + try: + # Import the module + module = importlib.import_module(module_name) + + # Find all BaseTool instances in the module + tools_found = 0 + for name, obj in inspect.getmembers(module): + if isinstance(obj, BaseTool): + self._register_tool( + tool_instance=obj, + name=name, + category=category, + module_path=module_name, + ) + tools_found += 1 + + if tools_found == 0: + logger.warning(f"No tools found in {module_name}") + else: + logger.debug(f"Loaded {tools_found} tool(s) from {module_name}") + + except ImportError as e: + logger.error( + f"Import error loading tools from {module_name}: {str(e)}. " + f"This tool will not be available." + ) + except Exception as e: + logger.error( + f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}" + ) + + def _register_tool( + self, *, tool_instance: BaseTool, name: str, category: str, module_path: str + ) -> None: + """Register a single tool in the registry. + + Args: + tool_instance: The LangChain tool instance + name: Function name of the tool + category: Category name ('shared' or 'customer') + module_path: Full Python module path + """ + tool_id = f"{category}.{name}" + + # Check for duplicate tool IDs + if tool_id in self._tools: + logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting") + + metadata = ToolMetadata( + tool_id=tool_id, + name=name, + category=category, + description=tool_instance.description or "", + tool_instance=tool_instance, + module_path=module_path, + ) + + self._tools[tool_id] = metadata + logger.debug(f"Registered tool: {tool_id}") + + def get_all_tools(self) -> List[ToolMetadata]: + """Get all registered tools. + + Returns: + List of all tool metadata objects + """ + return list(self._tools.values()) + + def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]: + """Get a specific tool by its ID. + + Args: + tool_id: The tool identifier (e.g., 'shared.tavily_search') + + Returns: + Tool metadata if found, None otherwise + """ + return self._tools.get(tool_id) + + def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]: + """Get all tools in a specific category. + + Args: + category: Category name ('shared' or 'customer') + + Returns: + List of tool metadata for the specified category + """ + return [t for t in self._tools.values() if t.category == category] + + def list_tool_ids(self) -> List[str]: + """Get a list of all registered tool IDs. + + Returns: + List of tool ID strings + """ + return list(self._tools.keys()) + + def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]: + """Get actual tool instances for a list of tool IDs. + + This is useful for filtering tools based on user permissions. + + Args: + tool_ids: List of tool IDs to retrieve + + Returns: + List of BaseTool instances for the specified IDs + """ + instances = [] + for tool_id in tool_ids: + metadata = self.get_tool(tool_id=tool_id) + if metadata: + instances.append(metadata.tool_instance) + else: + logger.warning(f"Tool ID not found in registry: {tool_id}") + return instances + + @property + def is_initialized(self) -> bool: + """Check if the registry has been initialized. + + Returns: + True if initialized, False otherwise + """ + return self._initialized + + +# Global registry instance +_registry: Optional[ToolRegistry] = None + + +def get_registry() -> ToolRegistry: + """Get the global tool registry instance. + + This function ensures the registry is initialized on first access. + Subsequent calls return the same instance. + + Returns: + The global ToolRegistry instance + """ + global _registry + + if _registry is None: + _registry = ToolRegistry() + + if not _registry.is_initialized: + _registry.initialize() + + return _registry + + +def reinitialize_registry() -> ToolRegistry: + """Force reinitialize the tool registry. + + This is useful for testing or when tools are added dynamically. + + Returns: + The reinitialized ToolRegistry instance + """ + global _registry + _registry = ToolRegistry() + _registry.initialize() + return _registry diff --git a/tests/features/chatBot/utils/test_toolRegistry.py b/tests/features/chatBot/utils/test_toolRegistry.py new file mode 100644 index 00000000..219752b7 --- /dev/null +++ b/tests/features/chatBot/utils/test_toolRegistry.py @@ -0,0 +1,198 @@ +"""Pytest tests for the tool registry. + +This module tests that the tool registry correctly discovers and catalogs +all tools in the chatbotTools directory. +""" + +import logging +import pytest +from modules.features.chatBot.utils.toolRegistry import ( + ToolMetadata, + ToolRegistry, + get_registry, + reinitialize_registry, +) +from langchain_core.tools import BaseTool + +logger = logging.getLogger(__name__) + + +class TestToolRegistry: + """Test suite for ToolRegistry class.""" + + @pytest.fixture + def registry(self) -> ToolRegistry: + """Provide a fresh registry instance for each test.""" + return reinitialize_registry() + + def test_registry_initialization(self, registry: ToolRegistry) -> None: + """Test that registry initializes correctly.""" + assert registry.is_initialized + assert isinstance(registry._tools, dict) + + def test_get_all_tools(self, registry: ToolRegistry) -> None: + """Test getting all registered tools.""" + all_tools = registry.get_all_tools() + assert isinstance(all_tools, list) + assert len(all_tools) > 0 + assert all(isinstance(tool, ToolMetadata) for tool in all_tools) + + # Log all discovered tools + logger.info(f"Found {len(all_tools)} tools in registry:") + for tool in all_tools: + logger.info(f"\n{tool}") + + def test_tool_metadata_structure(self, registry: ToolRegistry) -> None: + """Test that tool metadata has correct structure.""" + all_tools = registry.get_all_tools() + for tool in all_tools: + assert isinstance(tool.tool_id, str) + assert isinstance(tool.name, str) + assert isinstance(tool.category, str) + assert tool.category in ["shared", "customer"] + assert isinstance(tool.description, str) + assert isinstance(tool.tool_instance, BaseTool) + assert isinstance(tool.module_path, str) + + def test_list_tool_ids(self, registry: ToolRegistry) -> None: + """Test listing all tool IDs.""" + tool_ids = registry.list_tool_ids() + assert isinstance(tool_ids, list) + assert len(tool_ids) > 0 + assert all(isinstance(tool_id, str) for tool_id in tool_ids) + + # Check that tool IDs follow expected format + for tool_id in tool_ids: + assert "." in tool_id + category, name = tool_id.split(".", 1) + assert category in ["shared", "customer"] + + def test_get_specific_tool(self, registry: ToolRegistry) -> None: + """Test retrieving a specific tool by ID.""" + # Get all tool IDs first + tool_ids = registry.list_tool_ids() + if tool_ids: + # Test with first available tool + test_tool_id = tool_ids[0] + tool_metadata = registry.get_tool(tool_id=test_tool_id) + + assert tool_metadata is not None + assert isinstance(tool_metadata, ToolMetadata) + assert tool_metadata.tool_id == test_tool_id + + def test_get_nonexistent_tool(self, registry: ToolRegistry) -> None: + """Test retrieving a tool that doesn't exist.""" + tool_metadata = registry.get_tool(tool_id="nonexistent.tool") + assert tool_metadata is None + + def test_get_tools_by_category_shared(self, registry: ToolRegistry) -> None: + """Test getting all shared tools.""" + shared_tools = registry.get_tools_by_category(category="shared") + assert isinstance(shared_tools, list) + assert all(tool.category == "shared" for tool in shared_tools) + + def test_get_tools_by_category_customer(self, registry: ToolRegistry) -> None: + """Test getting all customer tools.""" + customer_tools = registry.get_tools_by_category(category="customer") + assert isinstance(customer_tools, list) + assert all(tool.category == "customer" for tool in customer_tools) + + def test_get_tool_instances(self, registry: ToolRegistry) -> None: + """Test getting tool instances by IDs.""" + tool_ids = registry.list_tool_ids() + if len(tool_ids) >= 2: + # Test with first two tools + test_ids = tool_ids[:2] + instances = registry.get_tool_instances(tool_ids=test_ids) + + assert isinstance(instances, list) + assert len(instances) == 2 + assert all(isinstance(inst, BaseTool) for inst in instances) + + def test_get_tool_instances_with_invalid_id(self, registry: ToolRegistry) -> None: + """Test getting tool instances with some invalid IDs.""" + tool_ids = registry.list_tool_ids() + if tool_ids: + # Mix valid and invalid IDs + test_ids = [tool_ids[0], "invalid.tool"] + instances = registry.get_tool_instances(tool_ids=test_ids) + + # Should only return the valid one + assert len(instances) == 1 + assert isinstance(instances[0], BaseTool) + + def test_global_registry_singleton(self) -> None: + """Test that get_registry returns same instance.""" + registry1 = get_registry() + registry2 = get_registry() + assert registry1 is registry2 + + def test_reinitialize_registry(self) -> None: + """Test that reinitialize creates new instance.""" + registry1 = get_registry() + registry2 = reinitialize_registry() + # Should be different instances after reinitialize + assert registry1 is not registry2 + assert registry2.is_initialized + + +class TestToolDiscovery: + """Test suite for tool discovery functionality.""" + + def test_discovers_at_least_one_tool(self) -> None: + """Test that at least one tool is discovered.""" + registry = get_registry() + tool_ids = registry.list_tool_ids() + + # At least one tool should be successfully loaded + assert len(tool_ids) >= 1, "Expected at least one tool to be discovered" + + def test_query_althaus_database_if_available(self) -> None: + """Test query_althaus_database tool if it was successfully loaded.""" + registry = get_registry() + tool = registry.get_tool(tool_id="customer.query_althaus_database") + + if tool is not None: + assert tool.name == "query_althaus_database" + assert tool.category == "customer" + assert "database" in tool.description.lower() + else: + # Tool may not have loaded due to import errors - log warning + import logging + + logging.warning( + "customer.query_althaus_database tool not found - " + "may have failed to import" + ) + + def test_tavily_search_if_available(self) -> None: + """Test tavily_search tool if it was successfully loaded.""" + registry = get_registry() + tool = registry.get_tool(tool_id="shared.tavily_search") + + if tool is not None: + assert tool.name == "tavily_search" + assert tool.category == "shared" + assert "search" in tool.description.lower() + else: + # Tool may not have loaded due to import errors - log warning + import logging + + logging.warning( + "shared.tavily_search tool not found - may have failed to import" + ) + + def test_tool_ids_have_correct_format(self) -> None: + """Test that all discovered tool IDs follow the expected format.""" + registry = get_registry() + tool_ids = registry.list_tool_ids() + + for tool_id in tool_ids: + # All tool IDs should have format: category.toolname + assert "." in tool_id, f"Tool ID {tool_id} missing category separator" + category, name = tool_id.split(".", 1) + assert category in [ + "shared", + "customer", + ], f"Tool {tool_id} has invalid category: {category}" + assert len(name) > 0, f"Tool {tool_id} has empty name" From 8707203ac2a22c3f6d8aa6d107b3dfa9daf02d42 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 3 Oct 2025 13:17:45 +0200 Subject: [PATCH 06/29] feat: mock chatbot permissions --- modules/features/chatBot/utils/permissions.py | 59 ++++++++++++++++++- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/modules/features/chatBot/utils/permissions.py b/modules/features/chatBot/utils/permissions.py index 45e306a5..f8e57a30 100644 --- a/modules/features/chatBot/utils/permissions.py +++ b/modules/features/chatBot/utils/permissions.py @@ -1,7 +1,60 @@ -# get_allowed_tools +"""Mock permissions module for chatbot access control. + +This module provides mock permission functions that will be replaced +with actual database-driven permissions in the future. +""" + +from datetime import datetime + +from modules.features.chatBot.utils.toolRegistry import get_registry -# get_allowed_models +# TODO: Replace these mock implementations with actual database queries -# get_system_prompt +def get_allowed_tools(*, user_id: str) -> list[str]: + """Get list of tool IDs that a user is allowed to use. + + This is a mock implementation that returns all available tools + regardless of user_id. In production, this will query the database + for user-specific permissions. + + Args: + user_id: The unique identifier of the user + + Returns: + List of tool IDs (e.g., ["shared.tavily_search", "customer.query_althaus_database"]) + """ + registry = get_registry() + return registry.list_tool_ids() + + +def get_allowed_models(*, user_id: str) -> list[str]: + """Get list of AI models that a user is allowed to use. + + This is a mock implementation that returns a fixed list of models. + In production, this will query the database for user-specific model permissions. + + Args: + user_id: The unique identifier of the user + + Returns: + List of model identifiers (e.g., ["gpt-5", "claude-4-5"]) + """ + return ["gpt-5", "claude-4-5"] + + +def get_system_prompt(*, user_id: str) -> str: + """Get the system prompt for a user's chatbot session. + + This is a mock implementation that returns a generic prompt with today's date. + In production, this will query the database for user-specific or role-specific prompts. + + Args: + user_id: The unique identifier of the user + + Returns: + The system prompt string with the current date + """ + current_date = datetime.now().strftime("%Y-%m-%d") + return f"You're a smart assistant. Today is {current_date}" From 4bfeded9d0f1760e953ef4268e1139866ab168c9 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 3 Oct 2025 16:48:33 +0200 Subject: [PATCH 07/29] feat: chatbot w/ streaming basics --- app.py | 17 + modules/datamodels/datamodelChat.py | 268 ++++++++++++---- modules/datamodels/datamodelChatbot.py | 130 ++++++++ .../sharedTools/toolStreamingStatus.py | 24 ++ modules/features/chatBot/domain/__init__.py | 1 + modules/features/chatBot/domain/chatbot.py | 301 ++++++++++++++++++ .../chatBot/domain/streaming_helper.py | 239 ++++++++++++++ modules/features/chatBot/service.py | 215 +++++++++++++ .../features/chatBot/utils/checkpointer.py | 95 ++++++ modules/features/chatBot/utils/permissions.py | 31 +- modules/routes/routeChatbot.py | 147 ++++----- requirements.txt | 6 +- 12 files changed, 1315 insertions(+), 159 deletions(-) create mode 100644 modules/datamodels/datamodelChatbot.py create mode 100644 modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py create mode 100644 modules/features/chatBot/domain/__init__.py create mode 100644 modules/features/chatBot/domain/chatbot.py create mode 100644 modules/features/chatBot/domain/streaming_helper.py create mode 100644 modules/features/chatBot/service.py create mode 100644 modules/features/chatBot/utils/checkpointer.py diff --git a/app.py b/app.py index ed4e7214..23eeb645 100644 --- a/app.py +++ b/app.py @@ -240,9 +240,26 @@ instanceLabel = APP_CONFIG.get("APP_ENV_LABEL") @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Application is starting up") + + # Initialize LangGraph checkpointer + from modules.features.chatBot.utils.checkpointer import ( + initialize_checkpointer, + close_checkpointer, + ) + + try: + await initialize_checkpointer() + logger.info("LangGraph checkpointer initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize LangGraph checkpointer: {str(e)}") + # Continue startup even if checkpointer fails to initialize + eventManager.start() yield + + # Cleanup eventManager.stop() + await close_checkpointer() logger.info("Application has been shut down") diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index a1640b5d..62fa691a 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -8,10 +8,18 @@ import uuid class ChatStat(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") - workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)") - messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)") - processingTime: Optional[float] = Field(None, description="Processing time in seconds") + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) + workflowId: Optional[str] = Field( + None, description="Foreign key to workflow (for workflow stats)" + ) + messageId: Optional[str] = Field( + None, description="Foreign key to message (for message stats)" + ) + processingTime: Optional[float] = Field( + None, description="Processing time in seconds" + ) tokenCount: Optional[int] = Field(None, description="Number of tokens processed") bytesSent: Optional[int] = Field(None, description="Number of bytes sent") bytesReceived: Optional[int] = Field(None, description="Number of bytes received") @@ -37,14 +45,23 @@ register_model_labels( class ChatLog(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) workflowId: str = Field(description="Foreign key to workflow") message: str = Field(description="Log message") type: str = Field(description="Log type (info, warning, error, etc.)") - timestamp: float = Field(default_factory=get_utc_timestamp, description="When the log entry was created (UTC timestamp in seconds)") + timestamp: float = Field( + default_factory=get_utc_timestamp, + description="When the log entry was created (UTC timestamp in seconds)", + ) status: Optional[str] = Field(None, description="Status of the log entry") - progress: Optional[float] = Field(None, description="Progress indicator (0.0 to 1.0)") - performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics") + progress: Optional[float] = Field( + None, description="Progress indicator (0.0 to 1.0)" + ) + performance: Optional[Dict[str, Any]] = Field( + None, description="Performance metrics" + ) register_model_labels( @@ -64,7 +81,9 @@ register_model_labels( class ChatDocument(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) messageId: str = Field(description="Foreign key to message") fileId: str = Field(description="Foreign key to file") fileName: str = Field(description="Name of the file") @@ -73,7 +92,9 @@ class ChatDocument(BaseModel, ModelMixin): roundNumber: Optional[int] = Field(None, description="Round number in workflow") taskNumber: Optional[int] = Field(None, description="Task number within round") actionNumber: Optional[int] = Field(None, description="Action number within task") - actionId: Optional[str] = Field(None, description="ID of the action that created this document") + actionId: Optional[str] = Field( + None, description="ID of the action that created this document" + ) register_model_labels( @@ -96,13 +117,19 @@ register_model_labels( class ContentMetadata(BaseModel, ModelMixin): size: int = Field(description="Content size in bytes") - pages: Optional[int] = Field(None, description="Number of pages for multi-page content") + pages: Optional[int] = Field( + None, description="Number of pages for multi-page content" + ) error: Optional[str] = Field(None, description="Processing error if any") width: Optional[int] = Field(None, description="Width in pixels for images/videos") - height: Optional[int] = Field(None, description="Height in pixels for images/videos") + height: Optional[int] = Field( + None, description="Height in pixels for images/videos" + ) colorMode: Optional[str] = Field(None, description="Color mode") fps: Optional[float] = Field(None, description="Frames per second for videos") - durationSec: Optional[float] = Field(None, description="Duration in seconds for media") + durationSec: Optional[float] = Field( + None, description="Duration in seconds for media" + ) mimeType: str = Field(description="MIME type of the content") base64Encoded: bool = Field(description="Whether the data is base64 encoded") @@ -144,7 +171,9 @@ register_model_labels( class ExtractedContent(BaseModel, ModelMixin): id: str = Field(description="Reference to source ChatDocument") - contents: List[ContentItem] = Field(default_factory=list, description="List of content items") + contents: List[ContentItem] = Field( + default_factory=list, description="List of content items" + ) register_model_labels( @@ -156,27 +185,53 @@ register_model_labels( }, ) + class ChatMessage(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) workflowId: str = Field(description="Foreign key to workflow") - parentMessageId: Optional[str] = Field(None, description="Parent message ID for threading") - documents: List[ChatDocument] = Field(default_factory=list, description="Associated documents") - documentsLabel: Optional[str] = Field(None, description="Label for the set of documents") + parentMessageId: Optional[str] = Field( + None, description="Parent message ID for threading" + ) + documents: List[ChatDocument] = Field( + default_factory=list, description="Associated documents" + ) + documentsLabel: Optional[str] = Field( + None, description="Label for the set of documents" + ) message: Optional[str] = Field(None, description="Message content") role: str = Field(description="Role of the message sender") status: str = Field(description="Status of the message (first, step, last)") - sequenceNr: int = Field(description="Sequence number of the message (set automatically)") - publishedAt: float = Field(default_factory=get_utc_timestamp, description="When the message was published (UTC timestamp in seconds)") + sequenceNr: int = Field( + description="Sequence number of the message (set automatically)" + ) + publishedAt: float = Field( + default_factory=get_utc_timestamp, + description="When the message was published (UTC timestamp in seconds)", + ) stats: Optional[ChatStat] = Field(None, description="Statistics for this message") - success: Optional[bool] = Field(None, description="Whether the message processing was successful") - actionId: Optional[str] = Field(None, description="ID of the action that produced this message") - actionMethod: Optional[str] = Field(None, description="Method of the action that produced this message") - actionName: Optional[str] = Field(None, description="Name of the action that produced this message") + success: Optional[bool] = Field( + None, description="Whether the message processing was successful" + ) + actionId: Optional[str] = Field( + None, description="ID of the action that produced this message" + ) + actionMethod: Optional[str] = Field( + None, description="Method of the action that produced this message" + ) + actionName: Optional[str] = Field( + None, description="Name of the action that produced this message" + ) roundNumber: Optional[int] = Field(None, description="Round number in workflow") taskNumber: Optional[int] = Field(None, description="Task number within round") actionNumber: Optional[int] = Field(None, description="Action number within task") - taskProgress: Optional[str] = Field(None, description="Task progress status: pending, running, success, fail, retry") - actionProgress: Optional[str] = Field(None, description="Action progress status: pending, running, success, fail") + taskProgress: Optional[str] = Field( + None, description="Task progress status: pending, running, success, fail, retry" + ) + actionProgress: Optional[str] = Field( + None, description="Action progress status: pending, running, success, fail" + ) register_model_labels( @@ -208,31 +263,135 @@ register_model_labels( class ChatWorkflow(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False) - mandateId: str = Field(description="ID of the mandate this workflow belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False) - status: str = Field(description="Current status of the workflow", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[ - {"value": "running", "label": {"en": "Running", "fr": "En cours"}}, - {"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}}, - {"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}}, - {"value": "error", "label": {"en": "Error", "fr": "Erreur"}}, - ]) - name: Optional[str] = Field(None, description="Name of the workflow", frontend_type="text", frontend_readonly=False, frontend_required=True) - currentRound: int = Field(description="Current round number", frontend_type="integer", frontend_readonly=True, frontend_required=False) - currentTask: int = Field(default=0, description="Current task number", frontend_type="integer", frontend_readonly=True, frontend_required=False) - currentAction: int = Field(default=0, description="Current action number", frontend_type="integer", frontend_readonly=True, frontend_required=False) - totalTasks: int = Field(default=0, description="Total number of tasks in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False) - totalActions: int = Field(default=0, description="Total number of actions in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False) - lastActivity: float = Field(default_factory=get_utc_timestamp, description="Timestamp of last activity (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False) - startedAt: float = Field(default_factory=get_utc_timestamp, description="When the workflow started (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False) - logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", frontend_type="text", frontend_readonly=True, frontend_required=False) - messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False) - stats: Optional[ChatStat] = Field(None, description="Workflow statistics", frontend_type="text", frontend_readonly=True, frontend_required=False) - tasks: list = Field(default_factory=list, description="List of tasks in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False) - workflowMode: str = Field(default="Actionplan", description="Workflow mode selector", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[ - {"value": "Actionplan", "label": {"en": "Action Plan", "fr": "Plan d'actions"}}, - {"value": "React", "label": {"en": "React", "fr": "Réactif"}}, - ]) - maxSteps: int = Field(default=5, description="Maximum number of iterations in react mode", frontend_type="integer", frontend_readonly=False, frontend_required=False) + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Primary key", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + mandateId: str = Field( + description="ID of the mandate this workflow belongs to", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + status: str = Field( + description="Current status of the workflow", + frontend_type="select", + frontend_readonly=False, + frontend_required=False, + frontend_options=[ + {"value": "running", "label": {"en": "Running", "fr": "En cours"}}, + {"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}}, + {"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}}, + {"value": "error", "label": {"en": "Error", "fr": "Erreur"}}, + ], + ) + name: Optional[str] = Field( + None, + description="Name of the workflow", + frontend_type="text", + frontend_readonly=False, + frontend_required=True, + ) + currentRound: int = Field( + description="Current round number", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + currentTask: int = Field( + default=0, + description="Current task number", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + currentAction: int = Field( + default=0, + description="Current action number", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + totalTasks: int = Field( + default=0, + description="Total number of tasks in the workflow", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + totalActions: int = Field( + default=0, + description="Total number of actions in the workflow", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + lastActivity: float = Field( + default_factory=get_utc_timestamp, + description="Timestamp of last activity (UTC timestamp in seconds)", + frontend_type="timestamp", + frontend_readonly=True, + frontend_required=False, + ) + startedAt: float = Field( + default_factory=get_utc_timestamp, + description="When the workflow started (UTC timestamp in seconds)", + frontend_type="timestamp", + frontend_readonly=True, + frontend_required=False, + ) + logs: List[ChatLog] = Field( + default_factory=list, + description="Workflow logs", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + messages: List[ChatMessage] = Field( + default_factory=list, + description="Messages in the workflow", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + stats: Optional[ChatStat] = Field( + None, + description="Workflow statistics", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + tasks: list = Field( + default_factory=list, + description="List of tasks in the workflow", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + workflowMode: str = Field( + default="Actionplan", + description="Workflow mode selector", + frontend_type="select", + frontend_readonly=False, + frontend_required=False, + frontend_options=[ + { + "value": "Actionplan", + "label": {"en": "Action Plan", "fr": "Plan d'actions"}, + }, + {"value": "React", "label": {"en": "React", "fr": "Réactif"}}, + ], + ) + maxSteps: int = Field( + default=5, + description="Maximum number of iterations in react mode", + frontend_type="integer", + frontend_readonly=False, + frontend_required=False, + ) register_model_labels( @@ -278,7 +437,10 @@ register_model_labels( "completed_tasks": {"en": "Completed Tasks", "fr": "Tâches terminées"}, "total_tasks": {"en": "Total Tasks", "fr": "Total des tâches"}, "execution_time": {"en": "Execution Time", "fr": "Temps d'exécution"}, - "final_results_count": {"en": "Final Results Count", "fr": "Nombre de résultats finaux"}, + "final_results_count": { + "en": "Final Results Count", + "fr": "Nombre de résultats finaux", + }, "error": {"en": "Error", "fr": "Erreur"}, "phase": {"en": "Phase", "fr": "Phase"}, }, @@ -300,5 +462,3 @@ register_model_labels( "userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"}, }, ) - - diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py new file mode 100644 index 00000000..906757b7 --- /dev/null +++ b/modules/datamodels/datamodelChatbot.py @@ -0,0 +1,130 @@ +"""Chatbot API models for request/response handling.""" + +from typing import List, Optional +from pydantic import BaseModel, Field +from modules.shared.attributeUtils import register_model_labels, ModelMixin + + +# Chatbot API Models +class MessageItem(BaseModel, ModelMixin): + """Individual message in a thread""" + + role: str = Field(..., description="Message role (user or assistant)") + content: str = Field(..., description="Message content") + timestamp: float = Field(..., description="Message timestamp (Unix timestamp)") + + +class ChatMessageRequest(BaseModel, ModelMixin): + """Request model for posting a chat message""" + + thread_id: Optional[str] = Field( + None, description="Thread ID (creates new thread if not provided)" + ) + message: str = Field(..., description="User message content") + + +class ChatMessageResponse(BaseModel, ModelMixin): + """Response model for posting a chat message""" + + thread_id: str = Field(..., description="Thread ID") + messages: List[MessageItem] = Field(..., description="All messages in thread") + + +class ThreadSummary(BaseModel, ModelMixin): + """Summary of a chat thread for list view""" + + thread_id: str = Field(..., description="Thread ID") + created_at: float = Field(..., description="Thread creation timestamp") + last_message: str = Field(..., description="Last message content") + message_count: int = Field(..., description="Total number of messages") + + +class ThreadListResponse(BaseModel, ModelMixin): + """Response model for listing all threads""" + + threads: List[ThreadSummary] = Field(..., description="List of thread summaries") + + +class ThreadDetail(BaseModel, ModelMixin): + """Detailed view of a single thread""" + + thread_id: str = Field(..., description="Thread ID") + created_at: float = Field(..., description="Thread creation timestamp") + messages: List[MessageItem] = Field( + ..., description="All messages in chronological order" + ) + + +class DeleteResponse(BaseModel, ModelMixin): + """Response model for delete operations""" + + message: str = Field(..., description="Confirmation message") + thread_id: str = Field(..., description="Deleted thread ID") + + +# Register model labels for internationalization +register_model_labels( + "MessageItem", + {"en": "Message Item", "fr": "Élément de message"}, + { + "role": {"en": "Role", "fr": "Rôle"}, + "content": {"en": "Content", "fr": "Contenu"}, + "timestamp": {"en": "Timestamp", "fr": "Horodatage"}, + }, +) + +register_model_labels( + "ChatMessageRequest", + {"en": "Chat Message Request", "fr": "Demande de message de chat"}, + { + "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + "message": {"en": "Message", "fr": "Message"}, + }, +) + +register_model_labels( + "ChatMessageResponse", + {"en": "Chat Message Response", "fr": "Réponse du message de chat"}, + { + "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + "messages": {"en": "Messages", "fr": "Messages"}, + }, +) + +register_model_labels( + "ThreadSummary", + {"en": "Thread Summary", "fr": "Résumé du fil"}, + { + "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + "created_at": {"en": "Created At", "fr": "Créé le"}, + "last_message": {"en": "Last Message", "fr": "Dernier message"}, + "message_count": {"en": "Message Count", "fr": "Nombre de messages"}, + }, +) + +register_model_labels( + "ThreadListResponse", + {"en": "Thread List Response", "fr": "Réponse de liste de fils"}, + { + "threads": {"en": "Threads", "fr": "Fils"}, + }, +) + +register_model_labels( + "ThreadDetail", + {"en": "Thread Detail", "fr": "Détail du fil"}, + { + "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + "created_at": {"en": "Created At", "fr": "Créé le"}, + "messages": {"en": "Messages", "fr": "Messages"}, + }, +) + +register_model_labels( + "DeleteResponse", + {"en": "Delete Response", "fr": "Réponse de suppression"}, + { + "message": {"en": "Message", "fr": "Message"}, + "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + }, +) diff --git a/modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py b/modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py new file mode 100644 index 00000000..f2587be8 --- /dev/null +++ b/modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py @@ -0,0 +1,24 @@ +"""Tool for sending streaming status updates to users.""" + +from langchain_core.tools import tool + + +@tool +def send_streaming_message(message: str) -> str: + """Send a streaming message to the user to provide updates during processing. + + Use this tool to send short status updates to the user while you are working + on their request. This helps keep the user informed about what you are doing. + + Args: + message: A short message describing what you are currently doing. + Examples: "Searching database for relevant information..." + "Analyzing search results..." + "Processing your request..." + + Returns: + A confirmation that the message was sent. + """ + # This tool doesn't actually do anything - it's just for the AI to signal + # what it's doing to the frontend via the tool call mechanism + return f"Status update sent: {message}" diff --git a/modules/features/chatBot/domain/__init__.py b/modules/features/chatBot/domain/__init__.py new file mode 100644 index 00000000..abd60dca --- /dev/null +++ b/modules/features/chatBot/domain/__init__.py @@ -0,0 +1 @@ +"""Domain logic for chatbot functionality.""" diff --git a/modules/features/chatBot/domain/chatbot.py b/modules/features/chatBot/domain/chatbot.py new file mode 100644 index 00000000..56129097 --- /dev/null +++ b/modules/features/chatBot/domain/chatbot.py @@ -0,0 +1,301 @@ +"""Chatbot domain logic with LangGraph integration.""" + +from dataclasses import dataclass +from typing import Annotated, AsyncIterator, Any +import logging + +from pydantic import BaseModel +from langchain_core.messages import ( + BaseMessage, + HumanMessage, + SystemMessage, + trim_messages, +) +from langgraph.graph.message import add_messages +from langgraph.graph import StateGraph, START, END +from langgraph.graph.state import CompiledStateGraph +from langgraph.prebuilt import ToolNode +from langchain_anthropic import ChatAnthropic + +from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper +from modules.features.chatBot.utils.toolRegistry import get_registry +from modules.shared.configuration import APP_CONFIG + +logger = logging.getLogger(__name__) + + +class ChatState(BaseModel): + """Represents the state of a chat session.""" + + messages: Annotated[list[BaseMessage], add_messages] + + +def get_langchain_model(*, model_name: str) -> ChatAnthropic: + """Map permission model names to LangChain ChatAnthropic models. + + Args: + model_name: The model name from permissions (e.g., "claude_4_5") + + Returns: + Configured ChatAnthropic instance + + Raises: + ValueError: If the model name is not supported + """ + # Model name mapping + model_mapping = { + "claude_4_5": "claude-4-5-sonnet", + # Add more mappings as needed + } + + anthropic_model = model_mapping.get(model_name) + if not anthropic_model: + logger.warning( + f"Unknown model name '{model_name}', defaulting to claude-4-5-sonnet" + ) + anthropic_model = "claude-4-5-sonnet" + + return ChatAnthropic( + model=anthropic_model, + api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"), + temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)), + max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)), + ) + + +@dataclass +class Chatbot: + """Represents a chatbot with LangGraph integration.""" + + model: Any + memory: Any + app: Any = None + system_prompt: str = "You are a helpful assistant." + context_window_size: int = 100000 + + @classmethod + async def create( + cls, + *, + model: Any, + memory: Any, + system_prompt: str, + tools: list, + context_window_size: int = 100000, + ) -> "Chatbot": + """Factory method to create and configure a Chatbot instance. + + Args: + model: The chat model to use. + memory: The chat memory checkpointer to use. + system_prompt: The system prompt to initialize the chatbot. + tools: List of LangChain tools the chatbot can use. + context_window_size: Maximum tokens for context window. + + Returns: + A configured Chatbot instance. + """ + instance = cls( + model=model, + memory=memory, + system_prompt=system_prompt, + context_window_size=context_window_size, + ) + instance.app = instance._build_app(memory=memory, tools=tools) + return instance + + def _build_app( + self, *, memory: Any, tools: list + ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]: + """Builds the chatbot application workflow using LangGraph. + + Args: + memory: The chat memory checkpointer to use. + tools: The list of tools the chatbot can use. + + Returns: + A compiled state graph representing the chatbot application. + """ + llm_with_tools = self.model.bind_tools(tools=tools) + + def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]: + """Selects a window of messages that fit within the context window size. + + Args: + msgs: The list of messages to select from. + + Returns: + A list of messages that fit within the context window size. + """ + + def approx_counter(items: list[BaseMessage]) -> int: + """Approximate token counter for messages. + + Args: + items: List of messages to count tokens for. + + Returns: + Approximate number of tokens in the messages. + """ + return sum(len(getattr(m, "content", "") or "") for m in items) + + return trim_messages( + msgs, + strategy="last", + token_counter=approx_counter, + max_tokens=self.context_window_size, + start_on="human", + end_on=("human", "tool"), + include_system=True, + ) + + def agent_node(state: ChatState) -> dict: + """Agent node for the chatbot workflow. + + Args: + state: The current chat state. + + Returns: + The updated chat state after processing. + """ + # Select the message window to fit in context (trim if needed) + window = select_window(state.messages) + + # Ensure the system prompt is present at the start + if not window or not isinstance(window[0], SystemMessage): + window = [SystemMessage(content=self.system_prompt)] + window + + # Call the LLM with tools + response = llm_with_tools.invoke(window) + + # Return the new state + return {"messages": [response]} + + def should_continue(state: ChatState) -> str: + """Determines whether to continue the workflow or end it. + + This conditional edge is called after the agent node to decide + whether to continue to the tools node (if the last message contains + tool calls) or to end the workflow (if no tool calls are present). + + Args: + state: The current chat state. + + Returns: + The next node to transition to ("tools" or END). + """ + # Get the last message + last_message = state.messages[-1] + + # Check if the last message contains tool calls + # If so, continue to the tools node; otherwise, end the workflow + return "tools" if getattr(last_message, "tool_calls", None) else END + + # Compose the workflow + workflow = StateGraph(ChatState) + workflow.add_node("agent", agent_node) + workflow.add_node("tools", ToolNode(tools=tools)) + workflow.add_edge(START, "agent") + workflow.add_conditional_edges("agent", should_continue) + workflow.add_edge("tools", "agent") + return workflow.compile(checkpointer=memory) + + async def chat( + self, *, message: str, chat_id: str = "default" + ) -> list[BaseMessage]: + """Processes a chat message and returns the chat history. + + Args: + message: The user message to process. + chat_id: The chat thread ID. + + Returns: + The list of messages in the chat history. + """ + # Set the right thread ID for memory + config = {"configurable": {"thread_id": chat_id}} + + # Single-turn chat (non-streaming) + result = await self.app.ainvoke( + {"messages": [HumanMessage(content=message)]}, config=config + ) + + # Extract and return the messages from the result + return result["messages"] + + async def stream_events( + self, *, message: str, chat_id: str = "default" + ) -> AsyncIterator[dict]: + """Stream UI-focused events using astream_events v2. + + Args: + message: The user message to process. + chat_id: Logical thread identifier; forwarded in the runnable config so + memory and tools are scoped per thread. + + Yields: + dict: One of: + - ``{"type": "status", "label": str}`` for short progress updates. + - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}`` + where ``chat_history`` only includes ``user``/``assistant`` roles. + - ``{"type": "error", "message": str}`` if an exception occurs. + """ + # Thread-aware config for LangGraph/LangChain + config = {"configurable": {"thread_id": chat_id}} + + def _is_root(ev: dict) -> bool: + """Return True if the event is from the root run (v2: empty parent_ids).""" + return not ev.get("parent_ids") + + try: + async for event in self.app.astream_events( + {"messages": [HumanMessage(content=message)]}, + config=config, + version="v2", + ): + etype = event.get("event") + ename = event.get("name") or "" + edata = event.get("data") or {} + + # Stream human-readable progress via the special send_streaming_message tool + if etype == "on_tool_start" and ename == "send_streaming_message": + tool_in = edata.get("input") or {} + msg = tool_in.get("message") + if isinstance(msg, str) and msg.strip(): + yield {"type": "status", "label": msg.strip()} + continue + + # Emit the final payload when the root run finishes + if etype == "on_chain_end" and _is_root(event): + output_obj = edata.get("output") + + # Extract message list from the graph's final output + final_msgs = ChatStreamingHelper.extract_messages_from_output( + output_obj=output_obj + ) + + # Normalize for the frontend (only user/assistant with text content) + chat_history_payload: list[dict] = [] + for m in final_msgs: + if isinstance(m, BaseMessage): + d = ChatStreamingHelper.message_to_dict(msg=m) + elif isinstance(m, dict): + d = ChatStreamingHelper.dict_message_to_dict(obj=m) + else: + continue + if d.get("role") in ("user", "assistant") and d.get("content"): + chat_history_payload.append(d) + + yield { + "type": "final", + "response": { + "thread": chat_id, + "chat_history": chat_history_payload, + }, + } + return + + except Exception as exc: + # Emit a single error envelope and end the stream + logger.error(f"Error in stream_events: {str(exc)}", exc_info=True) + yield {"type": "error", "message": f"Error processing request: {exc}"} diff --git a/modules/features/chatBot/domain/streaming_helper.py b/modules/features/chatBot/domain/streaming_helper.py new file mode 100644 index 00000000..f8c73b45 --- /dev/null +++ b/modules/features/chatBot/domain/streaming_helper.py @@ -0,0 +1,239 @@ +"""Streaming helper utilities for chat message processing and normalization.""" + +from typing import Any, Dict, List, Literal, Mapping, Optional + +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) + +Role = Literal["user", "assistant", "system", "tool"] + + +class ChatStreamingHelper: + """Pure helper methods for streaming and message normalization. + + This class provides static utility methods for converting between different + message formats, extracting content, and normalizing message structures + for streaming chat applications. + """ + + @staticmethod + def role_from_message(*, msg: BaseMessage) -> Role: + """Extract the role from a BaseMessage instance. + + Args: + msg: The BaseMessage instance to extract the role from. + + Returns: + The role as a string literal: "user", "assistant", "system", or "tool". + Defaults to "assistant" if the message type is not recognized. + + Examples: + >>> from langchain_core.messages import HumanMessage + >>> msg = HumanMessage(content="Hello") + >>> ChatStreamingHelper.role_from_message(msg=msg) + 'user' + """ + if isinstance(msg, HumanMessage): + return "user" + if isinstance(msg, AIMessage): + return "assistant" + if isinstance(msg, SystemMessage): + return "system" + if isinstance(msg, ToolMessage): + return "tool" + return getattr(msg, "role", "assistant") + + @staticmethod + def flatten_content(*, content: Any) -> str: + """Convert complex content structures to plain text. + + This method handles various content formats including strings, lists of + content parts, and dictionaries with text fields. It's designed to + normalize content from different message sources into a consistent + plain text format. + + Args: + content: The content to flatten. Can be: + - str: Returned as-is after stripping whitespace + - list: Each item processed and joined with newlines + - dict: Text extracted from "text" or "content" fields + - None: Returns empty string + - Any other type: Converted to string + + Returns: + The flattened content as a plain text string with whitespace stripped. + + Examples: + >>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}] + >>> ChatStreamingHelper.flatten_content(content=content) + 'Hello + nworld' + + >>> content = {"text": "Simple message"} + >>> ChatStreamingHelper.flatten_content(content=content) + 'Simple message' + """ + if content is None: + return "" + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + parts: List[str] = [] + for part in content: + if isinstance(part, dict): + if "text" in part and isinstance(part["text"], str): + parts.append(part["text"]) + elif part.get("type") == "text" and isinstance( + part.get("text"), str + ): + parts.append(part["text"]) + elif "content" in part and isinstance(part["content"], str): + parts.append(part["content"]) + else: + # Fallback for unknown dictionary structures + val = part.get("value") + if isinstance(val, str): + parts.append(val) + else: + parts.append(str(part)) + return "\n".join(p.strip() for p in parts if p is not None) + if isinstance(content, dict): + if "text" in content and isinstance(content["text"], str): + return content["text"].strip() + if "content" in content and isinstance(content["content"], str): + return content["content"].strip() + return str(content).strip() + + @staticmethod + def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]: + """Convert a BaseMessage instance to a dictionary for streaming output. + + This method normalizes BaseMessage instances into a consistent dictionary + format suitable for JSON serialization and streaming to clients. + + Args: + msg: The BaseMessage instance to convert. + + Returns: + A dictionary containing: + - "role": The message role (user, assistant, system, tool) + - "content": The flattened message content as plain text + - "tool_calls": Tool calls if present (optional) + - "name": Message name if present (optional) + + Examples: + >>> from langchain_core.messages import HumanMessage + >>> msg = HumanMessage(content="Hello there") + >>> result = ChatStreamingHelper.message_to_dict(msg=msg) + >>> result["role"] + 'user' + >>> result["content"] + 'Hello there' + """ + payload: Dict[str, Any] = { + "role": ChatStreamingHelper.role_from_message(msg=msg), + "content": ChatStreamingHelper.flatten_content( + content=getattr(msg, "content", "") + ), + } + tool_calls = getattr(msg, "tool_calls", None) + if tool_calls: + payload["tool_calls"] = tool_calls + name = getattr(msg, "name", None) + if name: + payload["name"] = name + return payload + + @staticmethod + def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]: + """Convert a dictionary-shaped message to a normalized dictionary. + + This method handles messages that come from serialized state and are + represented as dictionaries rather than BaseMessage instances. It + normalizes various dictionary formats into a consistent structure. + + Args: + obj: The dictionary-shaped message to convert. Expected to contain + fields like "role", "type", "content", "text", etc. + + Returns: + A normalized dictionary containing: + - "role": The message role (user, assistant, system, tool) + - "content": The flattened message content as plain text + - "tool_calls": Tool calls if present (optional) + - "name": Message name if present (optional) + + Examples: + >>> obj = {"type": "human", "content": "Hello"} + >>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj) + >>> result["role"] + 'user' + >>> result["content"] + 'Hello' + """ + role: Optional[str] = obj.get("role") + if not role: + # Handle alternative type field mappings + typ = obj.get("type") + if typ in ("human", "user"): + role = "user" + elif typ in ("ai", "assistant"): + role = "assistant" + elif typ in ("system",): + role = "system" + elif typ in ("tool", "function"): + role = "tool" + + content = obj.get("content") + if content is None and "text" in obj: + content = obj["text"] + + out: Dict[str, Any] = { + "role": role or "assistant", + "content": ChatStreamingHelper.flatten_content(content=content), + } + if "tool_calls" in obj: + out["tool_calls"] = obj["tool_calls"] + if obj.get("name"): + out["name"] = obj["name"] + return out + + @staticmethod + def extract_messages_from_output(*, output_obj: Any) -> List[Any]: + """Extract messages from LangGraph output objects. + + This method handles various output formats from LangGraph execution, + extracting the messages list from different possible structures. + + Args: + output_obj: The output object from LangGraph execution. Can be: + - An object with a "messages" attribute + - A dictionary with a "messages" key + - Any other object (returns empty list) + + Returns: + A list of extracted messages, or an empty list if no messages + are found or if the output object is None. + + Examples: + >>> output = {"messages": [{"role": "user", "content": "Hello"}]} + >>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output) + >>> len(messages) + 1 + """ + if output_obj is None: + return [] + + # Try to parse dicts first + if isinstance(output_obj, dict): + msgs = output_obj.get("messages") + return msgs if isinstance(msgs, list) else [] + + # Then try to get messages attribute + msgs = getattr(output_obj, "messages", None) + return msgs if isinstance(msgs, list) else [] diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py new file mode 100644 index 00000000..1899d8e9 --- /dev/null +++ b/modules/features/chatBot/service.py @@ -0,0 +1,215 @@ +"""Service layer for chatbot functionality.""" + +import json +import logging +from typing import AsyncIterator, List + +from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model +from modules.features.chatBot.utils.checkpointer import get_checkpointer +from modules.features.chatBot.utils.toolRegistry import get_registry +from modules.features.chatBot.utils import permissions +from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse +from modules.datamodels.datamodelUam import User + +from langchain_core.messages import HumanMessage, AIMessage +from modules.shared.configuration import APP_CONFIG + +logger = logging.getLogger(__name__) + + +async def post_message( + *, + thread_id: str, + message: str, + user: User, +) -> ChatMessageResponse: + """Post a chat message to the chatbot and return the response. + + Args: + thread_id: The unique identifier for the chat thread. + message: The content of the chat message. + user: The current user. + + Returns: + The response containing the full chat message history and thread ID. + """ + logger.info(f"User {user.id} posted message to thread {thread_id}") + + # Get user permissions + tool_ids = permissions.get_chatbot_tools(user_id=user.id) + if not tool_ids: + raise ValueError("User does not have permission to use any chatbot tools") + + model_name = permissions.get_chatbot_model(user_id=user.id) + system_prompt = permissions.get_system_prompt(user_id=user.id) + + # Get tools from registry + registry = get_registry() + tools = registry.get_tool_instances(tool_ids=tool_ids) + + # Get model and checkpointer + model = get_langchain_model(model_name=model_name) + checkpointer = get_checkpointer() + + # Get context window size from config + context_window_size = int( + APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) + ) + + # Create chatbot instance + chatbot = await Chatbot.create( + model=model, + memory=checkpointer, + system_prompt=system_prompt, + tools=tools, + context_window_size=context_window_size, + ) + + # Send message to chatbot + response = await chatbot.chat(message=message, chat_id=thread_id) + + # Parse the response to the correct format + messages = [] + for msg in response: + # Determine the role of the message + if isinstance(msg, HumanMessage): + role = "user" + elif isinstance(msg, AIMessage): + role = "assistant" + else: + continue # Skip any other message types + + # Skip messages that are structured content, such as tool calls + if not isinstance(msg.content, str): + continue + + # Append message to chat history + item = MessageItem( + role=role, + content=msg.content.strip(), + timestamp=0.0, # TODO: Add proper timestamp handling + ) + messages.append(item) + + return ChatMessageResponse(thread_id=thread_id, messages=messages) + + +async def post_message_stream( + *, + thread_id: str, + message: str, + user: User, +) -> AsyncIterator[str]: + """Post a chat message to the chatbot and stream progress updates (SSE). + + Args: + thread_id: The unique identifier for the chat thread. + message: The content of the chat message. + user: The current user. + + Yields: + Server-Sent Events formatted strings containing status updates and final response. + """ + logger.info(f"User {user.id} streaming message to thread {thread_id}") + + try: + # Get user permissions + tool_ids = permissions.get_chatbot_tools(user_id=user.id) + if not tool_ids: + yield ( + "data: " + + json.dumps( + { + "type": "error", + "message": "User does not have permission to use any chatbot tools", + } + ) + + "\n\n" + ) + return + + model_name = permissions.get_chatbot_model(user_id=user.id) + system_prompt = permissions.get_system_prompt(user_id=user.id) + + # Get tools from registry + registry = get_registry() + tools = registry.get_tool_instances(tool_ids=tool_ids) + + # Get model and checkpointer + model = get_langchain_model(model_name=model_name) + checkpointer = get_checkpointer() + + # Get context window size from config + context_window_size = int( + APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) + ) + + # Create chatbot instance + chatbot = await Chatbot.create( + model=model, + memory=checkpointer, + system_prompt=system_prompt, + tools=tools, + context_window_size=context_window_size, + ) + + # Stream events from chatbot + async for event in chatbot.stream_events(message=message, chat_id=thread_id): + etype = event.get("type") + + # Forward status updates + if etype == "status": + yield f"data: {json.dumps({'type': 'status', 'label': event.get('label')})}\n\n" + continue + + # Forward final response + if etype == "final": + response_from_event = event.get("response") or {} + + # Use the chat history from the final event (already normalized by stream_events) + chat_history_payload = response_from_event.get("chat_history", []) + if isinstance(chat_history_payload, list): + # Convert to MessageItem format + items: List[MessageItem] = [] + for it in chat_history_payload: + role = it.get("role") + content = it.get("content", "") + if role in ("user", "assistant") and content: + items.append( + MessageItem( + role=role, + content=content, + timestamp=0.0, # TODO: Add proper timestamp handling + ) + ) + + response = ChatMessageResponse(thread_id=thread_id, messages=items) + # Yield the final response and exit + yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n" + return + else: + # Unexpected payload format - log warning and return empty history + logger.warning( + f"Unexpected chat_history format in final event: {type(chat_history_payload)}" + ) + response = ChatMessageResponse(thread_id=thread_id, messages=[]) + yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n" + return + + # Forward error events + if etype == "error": + yield f"data: {json.dumps(event)}\n\n" + return + + except Exception as e: + logger.error(f"Error in streaming chat: {str(e)}", exc_info=True) + yield ( + "data: " + + json.dumps( + { + "type": "error", + "message": "An error occurred while processing your request.", + } + ) + + "\n\n" + ) diff --git a/modules/features/chatBot/utils/checkpointer.py b/modules/features/chatBot/utils/checkpointer.py new file mode 100644 index 00000000..0aebbda6 --- /dev/null +++ b/modules/features/chatBot/utils/checkpointer.py @@ -0,0 +1,95 @@ +"""PostgreSQL checkpointer utilities for LangGraph memory.""" + +import logging +from typing import Optional + +from langgraph.checkpoint.postgres import PostgresSaver +from psycopg_pool import AsyncConnectionPool +from modules.shared.configuration import APP_CONFIG + +logger = logging.getLogger(__name__) + +# Global checkpointer instance +_checkpointer_instance: Optional[PostgresSaver] = None +_connection_pool: Optional[AsyncConnectionPool] = None + + +async def initialize_checkpointer() -> None: + """Initialize the PostgreSQL checkpointer for LangGraph. + + This should be called during application startup. + Creates a connection pool and PostgresSaver instance. + """ + global _checkpointer_instance, _connection_pool + + if _checkpointer_instance is not None: + logger.info("Checkpointer already initialized") + return + + try: + # Get database configuration from environment + host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost") + database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat") + user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev") + password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET") + port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432") + + # Build connection string + connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}" + + # Create async connection pool + _connection_pool = AsyncConnectionPool( + conninfo=connection_string, + min_size=2, + max_size=10, + ) + + # Initialize the connection pool + await _connection_pool.open() + + # Create PostgresSaver with the pool + _checkpointer_instance = PostgresSaver(_connection_pool) + + # Setup the checkpointer (creates tables if needed) + await _checkpointer_instance.setup() + + logger.info("PostgreSQL checkpointer initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}") + raise + + +async def close_checkpointer() -> None: + """Close the checkpointer and connection pool. + + This should be called during application shutdown. + """ + global _checkpointer_instance, _connection_pool + + if _connection_pool is not None: + try: + await _connection_pool.close() + logger.info("PostgreSQL checkpointer connection pool closed") + except Exception as e: + logger.error(f"Error closing checkpointer connection pool: {str(e)}") + + _checkpointer_instance = None + _connection_pool = None + + +def get_checkpointer() -> PostgresSaver: + """Get the global PostgreSQL checkpointer instance. + + Returns: + The initialized PostgresSaver instance + + Raises: + RuntimeError: If checkpointer is not initialized + """ + if _checkpointer_instance is None: + raise RuntimeError( + "PostgreSQL checkpointer not initialized. " + "Call initialize_checkpointer() during application startup." + ) + return _checkpointer_instance diff --git a/modules/features/chatBot/utils/permissions.py b/modules/features/chatBot/utils/permissions.py index f8e57a30..d2fb4d65 100644 --- a/modules/features/chatBot/utils/permissions.py +++ b/modules/features/chatBot/utils/permissions.py @@ -12,36 +12,15 @@ from modules.features.chatBot.utils.toolRegistry import get_registry # TODO: Replace these mock implementations with actual database queries -def get_allowed_tools(*, user_id: str) -> list[str]: - """Get list of tool IDs that a user is allowed to use. - - This is a mock implementation that returns all available tools - regardless of user_id. In production, this will query the database - for user-specific permissions. - - Args: - user_id: The unique identifier of the user - - Returns: - List of tool IDs (e.g., ["shared.tavily_search", "customer.query_althaus_database"]) - """ +def get_chatbot_tools(*, user_id: str) -> list[str]: + """Get list of tool IDs that the chatbot can use for a given user.""" registry = get_registry() return registry.list_tool_ids() -def get_allowed_models(*, user_id: str) -> list[str]: - """Get list of AI models that a user is allowed to use. - - This is a mock implementation that returns a fixed list of models. - In production, this will query the database for user-specific model permissions. - - Args: - user_id: The unique identifier of the user - - Returns: - List of model identifiers (e.g., ["gpt-5", "claude-4-5"]) - """ - return ["gpt-5", "claude-4-5"] +def get_chatbot_model(*, user_id: str) -> str: + """Gets the chatbot model(s) a user is allowed to use.""" + return "claude_4_5" def get_system_prompt(*, user_id: str) -> str: diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index ff5fa9f4..01617582 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -1,13 +1,23 @@ -from pydantic import BaseModel, Field from fastapi import APIRouter, Depends, HTTPException, status from fastapi.requests import Request +from fastapi.responses import StreamingResponse from typing import Any, Dict, List, Optional from datetime import datetime import logging import uuid from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelChatbot import ( + ChatMessageRequest, + MessageItem, + ChatMessageResponse, + ThreadSummary, + ThreadListResponse, + ThreadDetail, + DeleteResponse, +) from modules.security.auth import getCurrentUser, limiter +from modules.features.chatBot import service as chat_service logger = logging.getLogger(__name__) @@ -17,68 +27,53 @@ router = APIRouter( responses={404: {"description": "Not found"}}, ) -# --- Pydantic models for requests and responses --- - - -class ChatMessageRequest(BaseModel): - """Request model for posting a chat message""" - - thread_id: Optional[str] = Field( - None, description="Thread ID (creates new thread if not provided)" - ) - message: str = Field(..., description="User message content") - - -class MessageItem(BaseModel): - """Individual message in a thread""" - - role: str = Field(..., description="Message role (user or assistant)") - content: str = Field(..., description="Message content") - timestamp: float = Field(..., description="Message timestamp (Unix timestamp)") - - -class ChatMessageResponse(BaseModel): - """Response model for posting a chat message""" - - thread_id: str = Field(..., description="Thread ID") - messages: List[MessageItem] = Field(..., description="All messages in thread") - - -class ThreadSummary(BaseModel): - """Summary of a chat thread for list view""" - - thread_id: str = Field(..., description="Thread ID") - created_at: float = Field(..., description="Thread creation timestamp") - last_message: str = Field(..., description="Last message content") - message_count: int = Field(..., description="Total number of messages") - - -class ThreadListResponse(BaseModel): - """Response model for listing all threads""" - - threads: List[ThreadSummary] = Field(..., description="List of thread summaries") - - -class ThreadDetail(BaseModel): - """Detailed view of a single thread""" - - thread_id: str = Field(..., description="Thread ID") - created_at: float = Field(..., description="Thread creation timestamp") - messages: List[MessageItem] = Field( - ..., description="All messages in chronological order" - ) - - -class DeleteResponse(BaseModel): - """Response model for delete operations""" - - message: str = Field(..., description="Confirmation message") - thread_id: str = Field(..., description="Deleted thread ID") - # --- Actual endpoints for chatbot --- +@router.post("/message/stream") +@limiter.limit("30/minute") +async def post_chat_message_stream( + *, + request: Request, + message_request: ChatMessageRequest, + currentUser: User = Depends(getCurrentUser), +) -> StreamingResponse: + """ + Post a message to a chat thread with streaming progress updates. + Creates a new thread if thread_id is not provided. + + Returns Server-Sent Events (SSE) stream with status updates and final response. + """ + try: + # Generate or use existing thread_id + thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" + + logger.info( + f"User {currentUser.id} posted streaming message to thread {thread_id}" + ) + + return StreamingResponse( + chat_service.post_message_stream( + thread_id=thread_id, + message=message_request.message, + user=currentUser, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + except Exception as e: + logger.error(f"Error posting chat message: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to post message: {str(e)}", + ) + + @router.post("/message", response_model=ChatMessageResponse) @limiter.limit("30/minute") async def post_chat_message( @@ -88,35 +83,31 @@ async def post_chat_message( currentUser: User = Depends(getCurrentUser), ) -> ChatMessageResponse: """ - Post a message to a chat thread and get assistant response. + Post a message to a chat thread and get assistant response (non-streaming). Creates a new thread if thread_id is not provided. - This endpoint will later be connected to LangGraph's checkpointer. + For streaming updates, use the /message/stream endpoint instead. """ try: # Generate or use existing thread_id thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" - # Get current timestamp - current_time = datetime.now().timestamp() - - # Create dummy message history - # In production, this will fetch from LangGraph's checkpointer - messages = [ - MessageItem( - role="user", content=message_request.message, timestamp=current_time - ), - MessageItem( - role="assistant", - content=f"Echo: {message_request.message} (This is a dummy response. LangGraph integration pending.)", - timestamp=current_time + 0.5, - ), - ] - logger.info(f"User {currentUser.id} posted message to thread {thread_id}") - return ChatMessageResponse(thread_id=thread_id, messages=messages) + response = await chat_service.post_message( + thread_id=thread_id, + message=message_request.message, + user=currentUser, + ) + return response + + except ValueError as e: + logger.error(f"Permission error: {str(e)}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e), + ) except Exception as e: logger.error(f"Error posting chat message: {str(e)}") raise HTTPException( diff --git a/requirements.txt b/requirements.txt index 28c8bb99..2378ca97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ fastapi==0.104.1 websockets==12.0 uvicorn==0.23.2 python-multipart==0.0.6 -httpx==0.25.0 +httpx>=0.25.2 pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility email-validator==2.0.0 # Required by Pydantic for email validation slowapi==0.1.8 # For rate limiting @@ -113,3 +113,7 @@ psycopg2-binary==2.9.9 langchain==0.3.27 langgraph==0.6.8 langchain-core==0.3.77 +langchain-anthropic==0.3.1 # For Claude models +psycopg[binary]==3.2.1 # For PostgreSQL async support (LangGraph checkpointer) +psycopg-pool==3.2.1 # Connection pooling for PostgreSQL +langgraph-checkpoint-postgres==2.0.24 From 37f01a215603e50a85fb2c543d76afcbeea5ed1d Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 3 Oct 2025 16:51:43 +0200 Subject: [PATCH 08/29] chore: fix pydantic v2 issue --- modules/connectors/connectorDbPostgre.py | 589 ++++++++++++++--------- 1 file changed, 362 insertions(+), 227 deletions(-) diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index c17fa2c3..236cb796 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -18,101 +18,140 @@ logger = logging.getLogger(__name__) # No mapping needed - table name = Pydantic model name exactly + class SystemTable(BaseModel, ModelMixin): """Data model for system table entries""" + table_name: str = Field( description="Name of the table", frontend_type="text", frontend_readonly=True, - frontend_required=True + frontend_required=True, ) initial_id: Optional[str] = Field( default=None, description="Initial ID for the table", frontend_type="text", frontend_readonly=True, - frontend_required=False + frontend_required=False, ) + def _get_model_fields(model_class) -> Dict[str, str]: """Get all fields from Pydantic model and map to SQL types.""" - if not hasattr(model_class, '__fields__'): + # Pydantic v2 uses model_fields instead of __fields__ + if hasattr(model_class, "model_fields"): + model_fields = model_class.model_fields + elif hasattr(model_class, "__fields__"): + model_fields = model_class.__fields__ + else: return {} - + fields = {} - for field_name, field_info in model_class.__fields__.items(): - field_type = field_info.type_ - + for field_name, field_info in model_fields.items(): + # Pydantic v2 uses annotation instead of type_ + field_type = ( + field_info.annotation + if hasattr(field_info, "annotation") + else field_info.type_ + ) + # Check for JSONB fields (Dict, List, or complex types) - if (field_type == dict or - field_type == list or - (hasattr(field_type, '__origin__') and field_type.__origin__ in (dict, list)) or - field_name in ['execParameters', 'expectedDocumentFormats', 'resultDocuments', 'logs', 'messages', 'stats', 'tasks']): - fields[field_name] = 'JSONB' + if ( + field_type == dict + or field_type == list + or ( + hasattr(field_type, "__origin__") + and field_type.__origin__ in (dict, list) + ) + or field_name + in [ + "execParameters", + "expectedDocumentFormats", + "resultDocuments", + "logs", + "messages", + "stats", + "tasks", + ] + ): + fields[field_name] = "JSONB" # Simple type mapping - elif field_type in (str, type(None)) or (get_origin(field_type) is Union and type(None) in get_args(field_type)): - fields[field_name] = 'TEXT' + elif field_type in (str, type(None)) or ( + get_origin(field_type) is Union and type(None) in get_args(field_type) + ): + fields[field_name] = "TEXT" elif field_type == int: - fields[field_name] = 'INTEGER' + fields[field_name] = "INTEGER" elif field_type == float: - fields[field_name] = 'DOUBLE PRECISION' + fields[field_name] = "DOUBLE PRECISION" elif field_type == bool: - fields[field_name] = 'BOOLEAN' + fields[field_name] = "BOOLEAN" else: - fields[field_name] = 'TEXT' # Default to TEXT - + fields[field_name] = "TEXT" # Default to TEXT + return fields + # No caching needed with proper database + class DatabaseConnector: """ A connector for PostgreSQL-based data storage. Provides generic database operations without user/mandate filtering. Uses PostgreSQL with JSONB columns for flexible data storage. """ - def __init__(self, dbHost: str, dbDatabase: str, dbUser: str = None, dbPassword: str = None, dbPort: int = None, userId: str = None): + + def __init__( + self, + dbHost: str, + dbDatabase: str, + dbUser: str = None, + dbPassword: str = None, + dbPort: int = None, + userId: str = None, + ): # Store the input parameters self.dbHost = dbHost self.dbDatabase = dbDatabase self.dbUser = dbUser self.dbPassword = dbPassword self.dbPort = dbPort - + # Set userId (default to empty string if None) self.userId = userId if userId is not None else "" - + # Initialize database system first (creates database if needed) self.connection = None self.initDbSystem() - + # No caching needed with proper database - PostgreSQL handles performance - + # Thread safety self._lock = threading.Lock() - + # Initialize system table self._systemTableName = "_system" self._initializeSystemTable() - - + def initDbSystem(self): """Initialize the database system - creates database and tables.""" try: # Create database if it doesn't exist self._create_database_if_not_exists() - + # Create tables self._create_tables() - + # Establish connection to the database self._connect() - + logger.info("PostgreSQL database system initialized successfully") except Exception as e: logger.error(f"FATAL ERROR: Database system initialization failed: {e}") raise - + def _create_database_if_not_exists(self): """Create the database if it doesn't exist.""" try: @@ -123,29 +162,32 @@ class DatabaseConnector: database="postgres", user=self.dbUser, password=self.dbPassword, - client_encoding='utf8' + client_encoding="utf8", ) conn.autocommit = True - + with conn.cursor() as cursor: # Check if database exists - cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)) + cursor.execute( + "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,) + ) exists = cursor.fetchone() - + if not exists: # Create database with proper quoting for names with hyphens quoted_db_name = f'"{self.dbDatabase}"' cursor.execute(f"CREATE DATABASE {quoted_db_name}") logger.info(f"Created database: {self.dbDatabase}") - + conn.close() - + except Exception as e: logger.error(f"FATAL ERROR: Cannot create database: {e}") logger.error("Database connection failed - application cannot start") - raise RuntimeError(f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}") - - + raise RuntimeError( + f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}" + ) + def _create_tables(self): """Create only the system table - application tables are created by interfaces.""" try: @@ -156,10 +198,10 @@ class DatabaseConnector: database=self.dbDatabase, user=self.dbUser, password=self.dbPassword, - client_encoding='utf8' + client_encoding="utf8", ) conn.autocommit = True - + with conn.cursor() as cursor: # Create only the system table cursor.execute(""" @@ -172,12 +214,14 @@ class DatabaseConnector: ) """) conn.close() - + except Exception as e: logger.error(f"FATAL ERROR: Cannot create system table: {e}") - logger.error("Database system table creation failed - application cannot start") + logger.error( + "Database system table creation failed - application cannot start" + ) raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}") - + def _connect(self): """Establish connection to PostgreSQL database.""" try: @@ -188,14 +232,14 @@ class DatabaseConnector: database=self.dbDatabase, user=self.dbUser, password=self.dbPassword, - client_encoding='utf8', - cursor_factory=psycopg2.extras.RealDictCursor + client_encoding="utf8", + cursor_factory=psycopg2.extras.RealDictCursor, ) self.connection.autocommit = False # Use transactions except Exception as e: logger.error(f"Failed to connect to PostgreSQL: {e}") raise - + def _ensure_connection(self): """Ensure database connection is alive, reconnect if necessary.""" try: @@ -208,72 +252,78 @@ class DatabaseConnector: except Exception as e: logger.warning(f"Connection lost, reconnecting: {e}") self._connect() - + def _initializeSystemTable(self): """Initializes the system table if it doesn't exist yet.""" try: # First ensure the system table exists self._ensureTableExists(SystemTable) - + with self.connection.cursor() as cursor: # Check if system table has any data cursor.execute('SELECT COUNT(*) FROM "_system"') row = cursor.fetchone() - count = row['count'] if row else 0 - + count = row["count"] if row else 0 + self.connection.commit() except Exception as e: logger.error(f"Error initializing system table: {e}") self.connection.rollback() raise - + def _loadSystemTable(self) -> Dict[str, str]: """Loads the system table with the initial IDs.""" try: with self.connection.cursor() as cursor: cursor.execute('SELECT "table_name", "initial_id" FROM "_system"') rows = cursor.fetchall() - + system_data = {} for row in rows: - system_data[row['table_name']] = row['initial_id'] - + system_data[row["table_name"]] = row["initial_id"] + return system_data except Exception as e: logger.error(f"Error loading system table: {e}") return {} - + def _saveSystemTable(self, data: Dict[str, str]) -> bool: """Saves the system table with the initial IDs.""" try: with self.connection.cursor() as cursor: # Clear existing data cursor.execute('DELETE FROM "_system"') - + # Insert new data for table_name, initial_id in data.items(): - cursor.execute(""" + cursor.execute( + """ INSERT INTO "_system" ("table_name", "initial_id", "_modifiedAt") VALUES (%s, %s, %s) - """, (table_name, initial_id, get_utc_timestamp())) - + """, + (table_name, initial_id, get_utc_timestamp()), + ) + self.connection.commit() return True except Exception as e: logger.error(f"Error saving system table: {e}") self.connection.rollback() return False - + def _ensureSystemTableExists(self) -> bool: """Ensures the system table exists, creates it if it doesn't.""" try: self._ensure_connection() - + with self.connection.cursor() as cursor: # Check if system table exists - cursor.execute("SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", (self._systemTableName,)) - exists = cursor.fetchone()['count'] > 0 - + cursor.execute( + "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", + (self._systemTableName,), + ) + exists = cursor.fetchone()["count"] > 0 + if not exists: # Create system table cursor.execute(f""" @@ -287,119 +337,142 @@ class DatabaseConnector: logger.info("System table created successfully") else: # Check if we need to add missing columns to existing table - cursor.execute(""" + cursor.execute( + """ SELECT column_name FROM information_schema.columns WHERE table_name = %s AND table_schema = 'public' - """, (self._systemTableName,)) - existing_columns = [row['column_name'] for row in cursor.fetchall()] - - if '_modifiedAt' not in existing_columns: - cursor.execute(f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "_modifiedAt" DOUBLE PRECISION') - + """, + (self._systemTableName,), + ) + existing_columns = [row["column_name"] for row in cursor.fetchall()] + + if "_modifiedAt" not in existing_columns: + cursor.execute( + f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "_modifiedAt" DOUBLE PRECISION' + ) + return True except Exception as e: logger.error(f"Error ensuring system table exists: {e}") return False - + def _ensureTableExists(self, model_class: type) -> bool: """Ensures a table exists, creates it if it doesn't.""" table = model_class.__name__ - + if table == "SystemTable": # Handle system table specially - it uses _system as the actual table name return self._ensureSystemTableExists() - + try: self._ensure_connection() - + with self.connection.cursor() as cursor: # Check if table exists by querying information_schema with case-insensitive search - cursor.execute(''' + cursor.execute( + """ SELECT COUNT(*) FROM information_schema.tables WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' - ''', (table,)) - exists = cursor.fetchone()['count'] > 0 - + """, + (table,), + ) + exists = cursor.fetchone()["count"] > 0 + if not exists: # Create table from Pydantic model self._create_table_from_model(cursor, table, model_class) - logger.info(f"Created table '{table}' with columns from Pydantic model") - + logger.info( + f"Created table '{table}' with columns from Pydantic model" + ) + self.connection.commit() return True except Exception as e: logger.error(f"Error ensuring table {table} exists: {e}") - if hasattr(self, 'connection') and self.connection: + if hasattr(self, "connection") and self.connection: self.connection.rollback() return False - def _create_table_from_model(self, cursor, table: str, model_class: type) -> None: """Create table with columns matching Pydantic model fields.""" fields = _get_model_fields(model_class) - + # Build column definitions with quoted identifiers to preserve exact case columns = ['"id" VARCHAR(255) PRIMARY KEY'] for field_name, sql_type in fields.items(): - if field_name != 'id': # Skip id, already defined + if field_name != "id": # Skip id, already defined columns.append(f'"{field_name}" {sql_type}') - + # Add metadata columns - columns.extend([ - '"_createdAt" DOUBLE PRECISION', - '"_modifiedAt" DOUBLE PRECISION', - '"_createdBy" VARCHAR(255)', - '"_modifiedBy" VARCHAR(255)' - ]) - + columns.extend( + [ + '"_createdAt" DOUBLE PRECISION', + '"_modifiedAt" DOUBLE PRECISION', + '"_createdBy" VARCHAR(255)', + '"_modifiedBy" VARCHAR(255)', + ] + ) + # Create table sql = f'CREATE TABLE IF NOT EXISTS "{table}" ({", ".join(columns)})' cursor.execute(sql) - + # Create indexes for foreign keys for field_name in fields: - if field_name.endswith('Id') and field_name != 'id': - cursor.execute(f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")') + if field_name.endswith("Id") and field_name != "id": + cursor.execute( + f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")' + ) - - def _save_record(self, cursor, table: str, recordId: str, record: Dict[str, Any], model_class: type) -> None: + def _save_record( + self, + cursor, + table: str, + recordId: str, + record: Dict[str, Any], + model_class: type, + ) -> None: """Save record to normalized table with explicit columns.""" # Get columns from Pydantic model instead of database schema fields = _get_model_fields(model_class) - columns = ['id'] + [field for field in fields.keys() if field != 'id'] + ['_createdAt', '_createdBy', '_modifiedAt', '_modifiedBy'] - - + columns = ( + ["id"] + + [field for field in fields.keys() if field != "id"] + + ["_createdAt", "_createdBy", "_modifiedAt", "_modifiedBy"] + ) + if not columns: logger.error(f"No columns found for table {table}") return - + # Filter record data to only include columns that exist in the table filtered_record = {k: v for k, v in record.items() if k in columns} - + # Ensure id is set - filtered_record['id'] = recordId - + filtered_record["id"] = recordId + # Prepare values in the correct order values = [] for col in columns: value = filtered_record.get(col) - + # Handle timestamp fields - store as Unix timestamps (floats) for consistency - if col in ['_createdAt', '_modifiedAt'] and value is not None: + if col in ["_createdAt", "_modifiedAt"] and value is not None: if isinstance(value, str): # Try to parse string as timestamp try: value = float(value) except: pass # Keep as string if parsing fails - + # Convert enum values to their string representation - elif hasattr(value, 'value'): + elif hasattr(value, "value"): value = value.value - + # Handle JSONB fields - ensure proper JSON format for PostgreSQL - elif col in fields and fields[col] == 'JSONB' and value is not None: + elif col in fields and fields[col] == "JSONB" and value is not None: import json + if isinstance(value, (dict, list)): # Convert Python objects to JSON string for PostgreSQL JSONB value = json.dumps(value) @@ -416,42 +489,51 @@ class DatabaseConnector: else: # Convert other types to JSON value = json.dumps(value) - + values.append(value) - - + # Build INSERT/UPDATE with quoted identifiers - col_names = ', '.join([f'"{col}"' for col in columns]) - placeholders = ', '.join(['%s'] * len(columns)) - updates = ', '.join([f'"{col}" = EXCLUDED."{col}"' for col in columns[1:] if col not in ['_createdAt', '_createdBy']]) - + col_names = ", ".join([f'"{col}"' for col in columns]) + placeholders = ", ".join(["%s"] * len(columns)) + updates = ", ".join( + [ + f'"{col}" = EXCLUDED."{col}"' + for col in columns[1:] + if col not in ["_createdAt", "_createdBy"] + ] + ) + sql = f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders}) ON CONFLICT ("id") DO UPDATE SET {updates}' - + cursor.execute(sql, values) - + def _loadRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]: """Loads a single record from the normalized table.""" table = model_class.__name__ - + try: if not self._ensureTableExists(model_class): return None - + with self.connection.cursor() as cursor: cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,)) row = cursor.fetchone() if not row: return None - + # Convert row to dict and handle JSONB fields record = dict(row) fields = _get_model_fields(model_class) - - + # Parse JSONB fields back to Python objects for field_name, field_type in fields.items(): - if field_type == 'JSONB' and field_name in record and record[field_name] is not None: + if ( + field_type == "JSONB" + and field_name in record + and record[field_name] is not None + ): import json + try: if isinstance(record[field_name], str): # Parse JSON string back to Python object @@ -464,26 +546,30 @@ class DatabaseConnector: record[field_name] = json.loads(str(record[field_name])) except (json.JSONDecodeError, TypeError, ValueError): # If parsing fails, keep as string - logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}") + logger.warning( + f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}" + ) pass - + return record except Exception as e: logger.error(f"Error loading record {recordId} from table {table}: {e}") return None - - def _saveRecord(self, model_class: type, recordId: str, record: Dict[str, Any]) -> bool: + + def _saveRecord( + self, model_class: type, recordId: str, record: Dict[str, Any] + ) -> bool: """Saves a single record to the table.""" table = model_class.__name__ - + try: if not self._ensureTableExists(model_class): return False - + recordId = str(recordId) if "id" in record and str(record["id"]) != recordId: raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}") - + # Add metadata currentTime = get_utc_timestamp() if "_createdAt" not in record: @@ -491,74 +577,85 @@ class DatabaseConnector: record["_createdBy"] = self.userId record["_modifiedAt"] = currentTime record["_modifiedBy"] = self.userId - + with self.connection.cursor() as cursor: self._save_record(cursor, table, recordId, record, model_class) - + self.connection.commit() return True except Exception as e: logger.error(f"Error saving record {recordId} to table {table}: {e}") self.connection.rollback() return False - + def _loadTable(self, model_class: type) -> List[Dict[str, Any]]: """Loads all records from a normalized table.""" table = model_class.__name__ - + if table == self._systemTableName: return self._loadSystemTable() - + try: if not self._ensureTableExists(model_class): return [] - + with self.connection.cursor() as cursor: cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"') records = [dict(row) for row in cursor.fetchall()] - + # Handle JSONB fields for all records fields = _get_model_fields(model_class) for record in records: for field_name, field_type in fields.items(): - if field_type == 'JSONB' and field_name in record: + if field_type == "JSONB" and field_name in record: if record[field_name] is None: # Convert None to appropriate default based on field name - if field_name in ['logs', 'messages', 'tasks', 'expectedDocumentFormats', 'resultDocuments']: + if field_name in [ + "logs", + "messages", + "tasks", + "expectedDocumentFormats", + "resultDocuments", + ]: record[field_name] = [] - elif field_name in ['execParameters', 'stats']: + elif field_name in ["execParameters", "stats"]: record[field_name] = {} else: record[field_name] = None else: import json + try: if isinstance(record[field_name], str): # Parse JSON string back to Python object - record[field_name] = json.loads(record[field_name]) + record[field_name] = json.loads( + record[field_name] + ) elif isinstance(record[field_name], (dict, list)): # Already a Python object, keep as is pass else: # Try to parse as JSON - record[field_name] = json.loads(str(record[field_name])) + record[field_name] = json.loads( + str(record[field_name]) + ) except (json.JSONDecodeError, TypeError, ValueError): # If parsing fails, keep as string - logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}") + logger.warning( + f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}" + ) pass - + return records except Exception as e: logger.error(f"Error loading table {table}: {e}") return [] - - - + def _registerInitialId(self, table: str, initialId: str) -> bool: """Registers the initial ID for a table.""" try: systemData = self._loadSystemTable() - + if table not in systemData: systemData[table] = initialId success = self._saveSystemTable(systemData) @@ -568,58 +665,64 @@ class DatabaseConnector: else: # Check if the existing initial ID still exists in the table existingInitialId = systemData[table] - records = self.getRecordset(model_class, recordFilter={"id": existingInitialId}) + records = self.getRecordset( + model_class, recordFilter={"id": existingInitialId} + ) if not records: # The initial record no longer exists, update to the new one systemData[table] = initialId success = self._saveSystemTable(systemData) if success: - logger.info(f"Initial ID updated from {existingInitialId} to {initialId} for table {table}") + logger.info( + f"Initial ID updated from {existingInitialId} to {initialId} for table {table}" + ) return success else: return True except Exception as e: logger.error(f"Error registering the initial ID for table {table}: {e}") return False - + def _removeInitialId(self, table: str) -> bool: """Removes the initial ID for a table from the system table.""" try: systemData = self._loadSystemTable() - + if table in systemData: del systemData[table] success = self._saveSystemTable(systemData) if success: - logger.info(f"Initial ID for table {table} removed from system table") + logger.info( + f"Initial ID for table {table} removed from system table" + ) return success return True # If not present, this is not an error except Exception as e: logger.error(f"Error removing initial ID for table {table}: {e}") return False - + def updateContext(self, userId: str) -> None: """Updates the context of the database connector.""" if userId is None: raise ValueError("userId must be provided") - + self.userId = userId # No cache to clear - database handles data consistency - + # Public API - + def getTables(self) -> List[str]: """Returns a list of all available tables.""" tables = [] - + try: # Ensure connection is alive self._ensure_connection() - + if not self.connection or self.connection.closed: logger.error("Database connection is not available") return tables - + with self.connection.cursor() as cursor: cursor.execute(""" SELECT table_name @@ -628,104 +731,121 @@ class DatabaseConnector: ORDER BY table_name """) rows = cursor.fetchall() - tables = [row['table_name'] for row in rows] + tables = [row["table_name"] for row in rows] except Exception as e: logger.error(f"Error reading the database {self.dbDatabase}: {e}") - + return tables - + def getFields(self, model_class: type) -> List[str]: """Returns a list of all fields in a table.""" data = self._loadTable(model_class) - + if not data: return [] - + fields = list(data[0].keys()) if data else [] - + return fields - - def getSchema(self, model_class: type, language: str = None) -> Dict[str, Dict[str, Any]]: + + def getSchema( + self, model_class: type, language: str = None + ) -> Dict[str, Dict[str, Any]]: """Returns a schema object for a table with data types and labels.""" data = self._loadTable(model_class) - + schema = {} - + if not data: return schema - + firstRecord = data[0] - + for field, value in firstRecord.items(): dataType = type(value).__name__ label = field - - schema[field] = { - "type": dataType, - "label": label - } - + + schema[field] = {"type": dataType, "label": label} + return schema - - def getRecordset(self, model_class: type, fieldFilter: List[str] = None, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]: + + def getRecordset( + self, + model_class: type, + fieldFilter: List[str] = None, + recordFilter: Dict[str, Any] = None, + ) -> List[Dict[str, Any]]: """Returns a list of records from a table, filtered by criteria.""" table = model_class.__name__ - + try: if not self._ensureTableExists(model_class): return [] - + # Build WHERE clause from recordFilter where_conditions = [] where_values = [] - + if recordFilter: for field, value in recordFilter.items(): where_conditions.append(f'"{field}" = %s') where_values.append(value) - + # Build the query if where_conditions: where_clause = " WHERE " + " AND ".join(where_conditions) else: where_clause = "" - + query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"' - + with self.connection.cursor() as cursor: cursor.execute(query, where_values) records = [dict(row) for row in cursor.fetchall()] - + # Handle JSONB fields for all records fields = _get_model_fields(model_class) for record in records: for field_name, field_type in fields.items(): - if field_type == 'JSONB' and field_name in record: + if field_type == "JSONB" and field_name in record: if record[field_name] is None: # Convert None to appropriate default based on field name - if field_name in ['logs', 'messages', 'tasks', 'expectedDocumentFormats', 'resultDocuments']: + if field_name in [ + "logs", + "messages", + "tasks", + "expectedDocumentFormats", + "resultDocuments", + ]: record[field_name] = [] - elif field_name in ['execParameters', 'stats']: + elif field_name in ["execParameters", "stats"]: record[field_name] = {} else: record[field_name] = None else: import json + try: if isinstance(record[field_name], str): # Parse JSON string back to Python object - record[field_name] = json.loads(record[field_name]) + record[field_name] = json.loads( + record[field_name] + ) elif isinstance(record[field_name], (dict, list)): # Already a Python object, keep as is pass else: # Try to parse as JSON - record[field_name] = json.loads(str(record[field_name])) + record[field_name] = json.loads( + str(record[field_name]) + ) except (json.JSONDecodeError, TypeError, ValueError): # If parsing fails, keep as string - logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}") + logger.warning( + f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}" + ) pass - + # If fieldFilter is available, reduce the fields if fieldFilter and isinstance(fieldFilter, list): result = [] @@ -736,13 +856,15 @@ class DatabaseConnector: filteredRecord[field] = record[field] result.append(filteredRecord) return result - + return records except Exception as e: logger.error(f"Error loading records from table {table}: {e}") return [] - - def recordCreate(self, model_class: type, record: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]: + + def recordCreate( + self, model_class: type, record: Union[Dict[str, Any], BaseModel] + ) -> Dict[str, Any]: """Creates a new record in a table based on Pydantic model class.""" # If record is a Pydantic model, convert to dict if isinstance(record, BaseModel): @@ -751,14 +873,14 @@ class DatabaseConnector: record = record.copy() else: raise ValueError("Record must be a Pydantic model or dictionary") - + # Ensure record has an ID if "id" not in record: record["id"] = str(uuid.uuid4()) - + # Save record self._saveRecord(model_class, record["id"], record) - + # Check if this is the first record in the table and register as initial ID table = model_class.__name__ existingInitialId = self.getInitialId(model_class) @@ -766,17 +888,19 @@ class DatabaseConnector: # This is the first record, register it as the initial ID self._registerInitialId(table, record["id"]) logger.info(f"Registered initial ID {record['id']} for table {table}") - + return record - - def recordModify(self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]: + + def recordModify( + self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel] + ) -> Dict[str, Any]: """Modifies an existing record in a table based on Pydantic model class.""" # Load existing record existingRecord = self._loadRecord(model_class, recordId) if not existingRecord: table = model_class.__name__ raise ValueError(f"Record {recordId} not found in table {table}") - + # If record is a Pydantic model, convert to dict if isinstance(record, BaseModel): record = to_dict(record) @@ -784,15 +908,19 @@ class DatabaseConnector: record = record.copy() else: raise ValueError("Record must be a Pydantic model or dictionary") - + # CRITICAL: Ensure we never modify the ID if "id" in record and str(record["id"]) != recordId: - logger.error(f"Attempted to modify record ID from {recordId} to {record['id']}") - raise ValueError("Cannot modify record ID - it must match the provided recordId") - + logger.error( + f"Attempted to modify record ID from {recordId} to {record['id']}" + ) + raise ValueError( + "Cannot modify record ID - it must match the provided recordId" + ) + # Update existing record with new data existingRecord.update(record) - + # Save updated record self._saveRecord(model_class, recordId, existingRecord) return existingRecord @@ -800,49 +928,56 @@ class DatabaseConnector: def recordDelete(self, model_class: type, recordId: str) -> bool: """Deletes a record from the table based on Pydantic model class.""" table = model_class.__name__ - + try: if not self._ensureTableExists(model_class): return False - + with self.connection.cursor() as cursor: # Check if record exists - cursor.execute(f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)) + cursor.execute( + f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,) + ) if not cursor.fetchone(): return False - + # Check if it's an initial record initialId = self.getInitialId(model_class) if initialId is not None and initialId == recordId: self._removeInitialId(table) - logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table") - + logger.info( + f"Initial ID {recordId} for table {table} has been removed from the system table" + ) + # Delete the record cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,)) - + # No cache to update - database handles consistency - + self.connection.commit() return True - + except Exception as e: logger.error(f"Error deleting record {recordId} from table {table}: {e}") self.connection.rollback() return False - def getInitialId(self, model_class: type) -> Optional[str]: """Returns the initial ID for a table.""" table = model_class.__name__ systemData = self._loadSystemTable() initialId = systemData.get(table) return initialId - + def close(self): """Close the database connection.""" - if hasattr(self, 'connection') and self.connection and not self.connection.closed: + if ( + hasattr(self, "connection") + and self.connection + and not self.connection.closed + ): self.connection.close() - + def __del__(self): """Cleanup method to close connection.""" try: From d33241e5dcbcecd4ca3f8b3f2fbdf4529b6c56c4 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 3 Oct 2025 17:00:04 +0200 Subject: [PATCH 09/29] chore: add authorize button to swagger docs --- app.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/app.py b/app.py index 23eeb645..2bc6d324 100644 --- a/app.py +++ b/app.py @@ -4,6 +4,8 @@ os.environ["NUMEXPR_MAX_THREADS"] = "12" from fastapi import FastAPI, HTTPException, Depends, Body, status, Response from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel +from fastapi.security import HTTPBearer from contextlib import asynccontextmanager @@ -268,8 +270,50 @@ app = FastAPI( title="PowerOn | Data Platform API", description=f"Backend API for the Multi-Agent Platform by ValueOn AG ({instanceLabel})", lifespan=lifespan, + swagger_ui_init_oauth={ + "usePkceWithAuthorizationCodeGrant": True, + }, ) +# Configure OpenAPI security scheme for Swagger UI +# This adds the "Authorize" button to the /docs page +security_scheme = HTTPBearer() +app.openapi_schema = None # Reset schema to regenerate with security + + +def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + + from fastapi.openapi.utils import get_openapi + + openapi_schema = get_openapi( + title=app.title, + version="1.0.0", + description=app.description, + routes=app.routes, + ) + + # Add security scheme definition + openapi_schema["components"]["securitySchemes"] = { + "BearerAuth": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + "description": "Enter your JWT token (obtained from login endpoint or browser cookies)", + } + } + + # Apply security globally to all endpoints + # Individual endpoints can override this if needed + openapi_schema["security"] = [{"BearerAuth": []}] + + app.openapi_schema = openapi_schema + return app.openapi_schema + + +app.openapi = custom_openapi + # Parse CORS origins from environment variable def get_allowed_origins(): From 33f8ff1b5e34fc8eda45aaac43e7e6db9377b036 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Mon, 6 Oct 2025 15:05:22 +0200 Subject: [PATCH 10/29] chore: better error messages --- modules/features/chatBot/service.py | 5 ++-- modules/routes/routeChatbot.py | 36 +++++++++++++++++++---------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index 1899d8e9..b88f1753 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -202,13 +202,14 @@ async def post_message_stream( return except Exception as e: - logger.error(f"Error in streaming chat: {str(e)}", exc_info=True) + error_msg = f"{type(e).__name__}: {str(e) or 'No error message provided'}" + logger.error(f"Error in streaming chat: {error_msg}", exc_info=True) yield ( "data: " + json.dumps( { "type": "error", - "message": "An error occurred while processing your request.", + "message": f"An error occurred while processing your request: {error_msg}", } ) + "\n\n" diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index 01617582..c0ee5bee 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -67,10 +67,12 @@ async def post_chat_message_stream( ) except Exception as e: - logger.error(f"Error posting chat message: {str(e)}") + logger.error( + f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to post message: {str(e)}", + detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}", ) @@ -103,16 +105,18 @@ async def post_chat_message( return response except ValueError as e: - logger.error(f"Permission error: {str(e)}") + logger.error(f"Permission error: {str(e)}", exc_info=True) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=str(e), + detail=str(e) or "Permission denied", ) except Exception as e: - logger.error(f"Error posting chat message: {str(e)}") + logger.error( + f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to post message: {str(e)}", + detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}", ) @@ -155,10 +159,12 @@ async def get_all_threads( return ThreadListResponse(threads=dummy_threads) except Exception as e: - logger.error(f"Error retrieving threads: {str(e)}") + logger.error( + f"Error retrieving threads: {type(e).__name__}: {str(e)}", exc_info=True + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve threads: {str(e)}", + detail=f"Failed to retrieve threads: {type(e).__name__}: {str(e) or 'No error message provided'}", ) @@ -207,10 +213,13 @@ async def get_thread_by_id( ) except Exception as e: - logger.error(f"Error retrieving thread {thread_id}: {str(e)}") + logger.error( + f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}", + exc_info=True, + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve thread: {str(e)}", + detail=f"Failed to retrieve thread: {type(e).__name__}: {str(e) or 'No error message provided'}", ) @@ -238,8 +247,11 @@ async def delete_thread( ) except Exception as e: - logger.error(f"Error deleting thread {thread_id}: {str(e)}") + logger.error( + f"Error deleting thread {thread_id}: {type(e).__name__}: {str(e)}", + exc_info=True, + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete thread: {str(e)}", + detail=f"Failed to delete thread: {type(e).__name__}: {str(e) or 'No error message provided'}", ) From dd16efb860b6d33b9a4a5872b5b065f4767c6f26 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Mon, 6 Oct 2025 15:39:19 +0200 Subject: [PATCH 11/29] fix: typo; async checkpointer postgres --- modules/features/chatBot/domain/chatbot.py | 2 +- modules/features/chatBot/utils/checkpointer.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/modules/features/chatBot/domain/chatbot.py b/modules/features/chatBot/domain/chatbot.py index 56129097..548ea966 100644 --- a/modules/features/chatBot/domain/chatbot.py +++ b/modules/features/chatBot/domain/chatbot.py @@ -44,7 +44,7 @@ def get_langchain_model(*, model_name: str) -> ChatAnthropic: """ # Model name mapping model_mapping = { - "claude_4_5": "claude-4-5-sonnet", + "claude_4_5": "claude-sonnet-4-5", # Add more mappings as needed } diff --git a/modules/features/chatBot/utils/checkpointer.py b/modules/features/chatBot/utils/checkpointer.py index 0aebbda6..957f5f5d 100644 --- a/modules/features/chatBot/utils/checkpointer.py +++ b/modules/features/chatBot/utils/checkpointer.py @@ -3,14 +3,15 @@ import logging from typing import Optional -from langgraph.checkpoint.postgres import PostgresSaver +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from psycopg_pool import AsyncConnectionPool +from psycopg.rows import dict_row from modules.shared.configuration import APP_CONFIG logger = logging.getLogger(__name__) # Global checkpointer instance -_checkpointer_instance: Optional[PostgresSaver] = None +_checkpointer_instance: Optional[AsyncPostgresSaver] = None _connection_pool: Optional[AsyncConnectionPool] = None @@ -42,13 +43,14 @@ async def initialize_checkpointer() -> None: conninfo=connection_string, min_size=2, max_size=10, + kwargs={"autocommit": True, "row_factory": dict_row}, ) # Initialize the connection pool await _connection_pool.open() - # Create PostgresSaver with the pool - _checkpointer_instance = PostgresSaver(_connection_pool) + # Create AsyncPostgresSaver with the pool + _checkpointer_instance = AsyncPostgresSaver(_connection_pool) # Setup the checkpointer (creates tables if needed) await _checkpointer_instance.setup() @@ -78,11 +80,11 @@ async def close_checkpointer() -> None: _connection_pool = None -def get_checkpointer() -> PostgresSaver: +def get_checkpointer() -> AsyncPostgresSaver: """Get the global PostgreSQL checkpointer instance. Returns: - The initialized PostgresSaver instance + The initialized AsyncPostgresSaver instance Raises: RuntimeError: If checkpointer is not initialized From 2a795492fe56a0623a148a27b4e3943660439d1b Mon Sep 17 00:00:00 2001 From: Ida Dittrich Date: Wed, 8 Oct 2025 10:24:15 +0200 Subject: [PATCH 12/29] fixed windows error --- app.py | 6 ++++++ modules/features/chatBot/utils/checkpointer.py | 9 +++++++++ requirements.txt | 2 +- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/app.py b/app.py index 2bc6d324..7d92601f 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,13 @@ import os +import sys +import asyncio os.environ["NUMEXPR_MAX_THREADS"] = "12" +# Fix for Windows asyncio compatibility with psycopg +if sys.platform == 'win32': + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + from fastapi import FastAPI, HTTPException, Depends, Body, status, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel diff --git a/modules/features/chatBot/utils/checkpointer.py b/modules/features/chatBot/utils/checkpointer.py index 957f5f5d..a51e7455 100644 --- a/modules/features/chatBot/utils/checkpointer.py +++ b/modules/features/chatBot/utils/checkpointer.py @@ -1,8 +1,17 @@ """PostgreSQL checkpointer utilities for LangGraph memory.""" +import sys +import asyncio import logging from typing import Optional +# Fix for Windows asyncio compatibility with psycopg (backup in case app.py fix didn't apply) +if sys.platform == 'win32': + try: + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + except RuntimeError: + pass # Already set + from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from psycopg_pool import AsyncConnectionPool from psycopg.rows import dict_row diff --git a/requirements.txt b/requirements.txt index 2378ca97..4a089c4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ ## Web Framework & API -fastapi==0.104.1 +fastapi==0.115.0 # Upgraded for Pydantic v2 compatibility websockets==12.0 uvicorn==0.23.2 python-multipart==0.0.6 From 8f96c3ef30aab11b8d639303fd066f9bd2fed9db Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 10:46:33 +0200 Subject: [PATCH 13/29] chore: add todos --- modules/routes/routeChatbot.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index c0ee5bee..48f74e5b 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -46,6 +46,10 @@ async def post_chat_message_stream( Returns Server-Sent Events (SSE) stream with status updates and final response. """ try: + # TODO: Add helper here, if no thread id is provided, add entry in mapping table. + + # TODO: If not provided, create new thread in LangGraph's checkpointer, and add it to mapping table. + # Generate or use existing thread_id thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" From 30c3f9f7f1263d8e2c552c03774ca7ac4dd08283 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 11:37:36 +0200 Subject: [PATCH 14/29] chore: add user threads db table setup --- app.py | 41 ++++++++++++++++++++++++++-- modules/features/chatBot/database.py | 40 +++++++++++++++++++++++++++ requirements.txt | 2 ++ 3 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 modules/features/chatBot/database.py diff --git a/app.py b/app.py index 7d92601f..b75b59e8 100644 --- a/app.py +++ b/app.py @@ -1,11 +1,14 @@ import os import sys import asyncio +from urllib.parse import quote_plus +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession +from modules.features.chatBot.database import init_models as init_chatbot_models os.environ["NUMEXPR_MAX_THREADS"] = "12" # Fix for Windows asyncio compatibility with psycopg -if sys.platform == 'win32': +if sys.platform == "win32": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) from fastapi import FastAPI, HTTPException, Depends, Body, status, Response @@ -238,6 +241,15 @@ def initLogging(): ) +def make_sqlalchemy_db_url() -> str: + host = APP_CONFIG.get("SQLALCHEMY_DB_HOST", "localhost") + port = APP_CONFIG.get("SQLALCHEMY_DB_PORT", "5432") + db = APP_CONFIG.get("SQLALCHEMY_DB_DATABASE", "project_gateway") + user = APP_CONFIG.get("SQLALCHEMY_DB_USER", "postgres") + pwd = quote_plus(APP_CONFIG.get("SQLALCHEMY_DB_PASSWORD_SECRET", "")) + return f"postgresql+psycopg://{user}:{pwd}@{host}:{port}/{db}" + + # Initialize logging initLogging() logger = logging.getLogger(__name__) @@ -249,7 +261,22 @@ instanceLabel = APP_CONFIG.get("APP_ENV_LABEL") async def lifespan(app: FastAPI): logger.info("Application is starting up") - # Initialize LangGraph checkpointer + # --- Init SQLAlchemy --- + + engine = create_async_engine( + make_sqlalchemy_db_url(), pool_pre_ping=True, echo=False + ) + SessionLocal = async_sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) + app.state.checkpoint_engine = engine + app.state.checkpoint_sessionmaker = SessionLocal + + # NOTE: Might need Alembic migrations in the future + await init_chatbot_models(engine) + + # --- Initialize LangGraph checkpointer --- + from modules.features.chatBot.utils.checkpointer import ( initialize_checkpointer, close_checkpointer, @@ -262,12 +289,20 @@ async def lifespan(app: FastAPI): logger.error(f"Failed to initialize LangGraph checkpointer: {str(e)}") # Continue startup even if checkpointer fails to initialize + # --- Init Event Manager --- eventManager.start() + yield - # Cleanup + # --- Cleanup Event Manager --- eventManager.stop() + + # --- Cleanup LangGraph checkpointer --- await close_checkpointer() + + # --- Cleanup SQLAlchemy --- + await engine.dispose() + logger.info("Application has been shut down") diff --git a/modules/features/chatBot/database.py b/modules/features/chatBot/database.py new file mode 100644 index 00000000..fc190205 --- /dev/null +++ b/modules/features/chatBot/database.py @@ -0,0 +1,40 @@ +from typing import AsyncIterator +from fastapi import Request +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy import String + + +class Base(DeclarativeBase): + pass + + +# User Thread Mapping Table +class UserThreadMapping(Base): + """Mapping of users to their chat threads. + + Used to keep track of which user owns which chat thread. + Also stores meta data like thread name. + """ + + __tablename__ = "userThreads" + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + userId: Mapped[int] = mapped_column(nullable=False) + threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + threadName: Mapped[str] = mapped_column(String(255), nullable=False) + + +# Dependency that pulls the sessionmaker off app.state +# This is set in app.py on startup in @asynccontextmanager +async def get_session(request: Request) -> AsyncIterator[AsyncSession]: + SessionLocal: async_sessionmaker[AsyncSession] = ( + request.app.state.checkpoint_sessionmaker + ) + async with SessionLocal() as session: + yield session + + +# Optional helper to init tables at startup (demo only) +async def init_models(engine) -> None: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) diff --git a/requirements.txt b/requirements.txt index 4a089c4b..6e8f2399 100644 --- a/requirements.txt +++ b/requirements.txt @@ -117,3 +117,5 @@ langchain-anthropic==0.3.1 # For Claude models psycopg[binary]==3.2.1 # For PostgreSQL async support (LangGraph checkpointer) psycopg-pool==3.2.1 # Connection pooling for PostgreSQL langgraph-checkpoint-postgres==2.0.24 + +greenlet==3.2.4 \ No newline at end of file From b50dcc6c0f96e60e476c8c9d05e3ffe1edf5851b Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 14:41:29 +0200 Subject: [PATCH 15/29] feat: save chatbot threads to db --- modules/features/chatBot/database.py | 15 ++- modules/features/chatBot/service.py | 166 ++++++++++++++++++++++++++- modules/routes/routeChatbot.py | 26 +++-- 3 files changed, 194 insertions(+), 13 deletions(-) diff --git a/modules/features/chatBot/database.py b/modules/features/chatBot/database.py index fc190205..ba67a28b 100644 --- a/modules/features/chatBot/database.py +++ b/modules/features/chatBot/database.py @@ -1,8 +1,10 @@ from typing import AsyncIterator +import uuid + from fastapi import Request from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy import String +from sqlalchemy import String, Uuid class Base(DeclarativeBase): @@ -15,18 +17,23 @@ class UserThreadMapping(Base): Used to keep track of which user owns which chat thread. Also stores meta data like thread name. + + 1:N relationship between user and thread. A thread belongs to exactly one user. + A user can have multiple threads. + Thread_id is unique in the table. """ __tablename__ = "userThreads" - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - userId: Mapped[int] = mapped_column(nullable=False) + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + userId: Mapped[str] = mapped_column(String(255), nullable=False) threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) threadName: Mapped[str] = mapped_column(String(255), nullable=False) # Dependency that pulls the sessionmaker off app.state # This is set in app.py on startup in @asynccontextmanager -async def get_session(request: Request) -> AsyncIterator[AsyncSession]: +# TODO: If we use SQLAlchemy in other places, we can move this to a shared module +async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]: SessionLocal: async_sessionmaker[AsyncSession] = ( request.app.state.checkpoint_sessionmaker ) diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index b88f1753..86442ae6 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -2,12 +2,16 @@ import json import logging -from typing import AsyncIterator, List +from typing import AsyncIterator, List, Optional + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model from modules.features.chatBot.utils.checkpointer import get_checkpointer from modules.features.chatBot.utils.toolRegistry import get_registry from modules.features.chatBot.utils import permissions +from modules.features.chatBot.database import UserThreadMapping from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse from modules.datamodels.datamodelUam import User @@ -17,6 +21,166 @@ from modules.shared.configuration import APP_CONFIG logger = logging.getLogger(__name__) +async def save_thread_for_user( + *, + thread_id: str, + user: User, + session: AsyncSession, + thread_name: str = "New Chat", + title: str = "New Chat", +) -> None: + """Save a new chat thread mapping for the user. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + session: The database session for saving. + thread_name: The name of the chat thread. Defaults to "New Chat". + title: Optional title for the chat (currently unused). + """ + logger.info(f"Saving new thread {thread_id} for user {user.id}") + + # Create new mapping entry + new_mapping = UserThreadMapping( + userId=user.id, + threadId=thread_id, + threadName=thread_name, + ) + + session.add(new_mapping) + await session.commit() + + logger.info(f"Successfully saved thread {thread_id} for user {user.id}") + + +async def get_or_create_thread_for_user( + *, + thread_id: Optional[str], + user: User, + session: AsyncSession, + thread_name: str = "New Chat", +) -> str: + """Get an existing thread or create a new one for the user. + + If thread_id is provided, verifies it exists and belongs to the user. + If thread_id is None, generates a new thread_id and saves it. + + Args: + thread_id: Optional thread identifier. If None, creates a new thread. + user: The current user. + session: The database session for querying/saving. + thread_name: The name for the thread if creating new. Defaults to "New Chat". + + Returns: + The thread_id to use (either the provided one or newly created). + + Raises: + PermissionError: If the thread does not belong to the user. + ValueError: If the provided thread_id does not exist. + """ + if thread_id: + # If the user provided a thread_id, verify it exists and belongs to them + await assure_thread_exists_and_belongs_to_user( + thread_id=thread_id, user=user, session=session + ) + logger.info(f"Using existing thread {thread_id} for user {user.id}") + return thread_id + else: + # Generate new thread_id if the user did not provide a thread_id + import uuid + + new_thread_id = f"thread_{uuid.uuid4()}" + await save_thread_for_user( + thread_id=new_thread_id, + user=user, + session=session, + thread_name=thread_name, + ) + logger.info(f"Created new thread {new_thread_id} for user {user.id}") + return new_thread_id + + +async def assure_thread_exists_and_belongs_to_user( + *, + thread_id: str, + user: User, + session: AsyncSession, +) -> None: + """Ensure that the given thread ID exists and belongs to the specified user. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + session: The database session for querying. + Raises: + PermissionError: If the thread does not belong to the user. + ValueError: If the thread does not exist. + """ + # Query the database for the thread mapping + stmt = select(UserThreadMapping).where(UserThreadMapping.threadId == thread_id) + result = await session.execute(stmt) + thread_mapping = result.scalar_one_or_none() + + # Check if thread exists + if thread_mapping is None: + logger.warning(f"Thread {thread_id} does not exist") + raise ValueError(f"Thread {thread_id} does not exist") + + # Check if thread belongs to the user + if thread_mapping.userId != user.id: + logger.warning( + f"User {user.id} attempted to access thread {thread_id} " + f"belonging to user {thread_mapping.userId}" + ) + raise PermissionError( + f"You do not have permission to access thread {thread_id}" + ) + + logger.info(f"Thread {thread_id} verified for user {user.id}") + + +async def update_thread_name( + *, + thread_id: str, + user: User, + new_thread_name: str, + session: AsyncSession, +) -> None: + """Update the name of an existing chat thread. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + new_thread_name: The new name to set for the thread. + session: The database session for updating. + + Raises: + PermissionError: If the thread does not belong to the user. + ValueError: If the thread does not exist. + """ + # Verify thread exists and belongs to user + await assure_thread_exists_and_belongs_to_user( + thread_id=thread_id, + user=user, + session=session, + ) + + logger.info( + f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}" + ) + + # Update the thread name + stmt = ( + update(UserThreadMapping) + .where(UserThreadMapping.threadId == thread_id) + .values(threadName=new_thread_name) + ) + await session.execute(stmt) + await session.commit() + + logger.info(f"Successfully updated thread {thread_id} name for user {user.id}") + + async def post_message( *, thread_id: str, diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index 48f74e5b..8f3efd65 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -5,7 +5,11 @@ from typing import Any, Dict, List, Optional from datetime import datetime import logging import uuid +from sqlalchemy.ext.asyncio import AsyncSession + +from modules.features.chatBot.database import get_async_db_session +from modules.features.chatBot.service import get_or_create_thread_for_user from modules.datamodels.datamodelUam import User from modules.datamodels.datamodelChatbot import ( ChatMessageRequest, @@ -38,6 +42,7 @@ async def post_chat_message_stream( request: Request, message_request: ChatMessageRequest, currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), ) -> StreamingResponse: """ Post a message to a chat thread with streaming progress updates. @@ -46,12 +51,12 @@ async def post_chat_message_stream( Returns Server-Sent Events (SSE) stream with status updates and final response. """ try: - # TODO: Add helper here, if no thread id is provided, add entry in mapping table. - - # TODO: If not provided, create new thread in LangGraph's checkpointer, and add it to mapping table. - - # Generate or use existing thread_id - thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" + # Get or create thread using helper function + thread_id = await get_or_create_thread_for_user( + thread_id=message_request.thread_id, + user=currentUser, + session=session, + ) logger.info( f"User {currentUser.id} posted streaming message to thread {thread_id}" @@ -87,6 +92,7 @@ async def post_chat_message( request: Request, message_request: ChatMessageRequest, currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), ) -> ChatMessageResponse: """ Post a message to a chat thread and get assistant response (non-streaming). @@ -95,8 +101,12 @@ async def post_chat_message( For streaming updates, use the /message/stream endpoint instead. """ try: - # Generate or use existing thread_id - thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" + # Get or create thread using helper function + thread_id = await get_or_create_thread_for_user( + thread_id=message_request.thread_id, + user=currentUser, + session=session, + ) logger.info(f"User {currentUser.id} posted message to thread {thread_id}") From 1143e181e851a6fba046a263af3fc79a06a74005 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 15:30:14 +0200 Subject: [PATCH 16/29] feat: add and handle chatbot thread date_created and date_modified --- modules/features/chatBot/database.py | 13 +++- modules/features/chatBot/service.py | 98 +++++++++++++++++++++++----- modules/routes/routeChatbot.py | 8 ++- 3 files changed, 102 insertions(+), 17 deletions(-) diff --git a/modules/features/chatBot/database.py b/modules/features/chatBot/database.py index ba67a28b..50dc9bba 100644 --- a/modules/features/chatBot/database.py +++ b/modules/features/chatBot/database.py @@ -1,10 +1,11 @@ from typing import AsyncIterator import uuid +from datetime import datetime, timezone from fastapi import Request from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy import String, Uuid +from sqlalchemy import String, Uuid, DateTime class Base(DeclarativeBase): @@ -28,6 +29,16 @@ class UserThreadMapping(Base): userId: Mapped[str] = mapped_column(String(255), nullable=False) threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) threadName: Mapped[str] = mapped_column(String(255), nullable=False) + date_created: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + date_updated: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) # Dependency that pulls the sessionmaker off app.state diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index 86442ae6..05df7f0f 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -2,6 +2,7 @@ import json import logging +from datetime import datetime, timezone from typing import AsyncIterator, List, Optional from sqlalchemy import select, update @@ -27,7 +28,6 @@ async def save_thread_for_user( user: User, session: AsyncSession, thread_name: str = "New Chat", - title: str = "New Chat", ) -> None: """Save a new chat thread mapping for the user. @@ -36,7 +36,6 @@ async def save_thread_for_user( user: The current user. session: The database session for saving. thread_name: The name of the chat thread. Defaults to "New Chat". - title: Optional title for the chat (currently unused). """ logger.info(f"Saving new thread {thread_id} for user {user.id}") @@ -59,6 +58,7 @@ async def get_or_create_thread_for_user( user: User, session: AsyncSession, thread_name: str = "New Chat", + refresh_date_updated: bool = False, ) -> str: """Get an existing thread or create a new one for the user. @@ -70,6 +70,7 @@ async def get_or_create_thread_for_user( user: The current user. session: The database session for querying/saving. thread_name: The name for the thread if creating new. Defaults to "New Chat". + refresh_date_updated: If True, refreshes date_updated for existing threads. Defaults to False. Returns: The thread_id to use (either the provided one or newly created). @@ -84,6 +85,13 @@ async def get_or_create_thread_for_user( thread_id=thread_id, user=user, session=session ) logger.info(f"Using existing thread {thread_id} for user {user.id}") + + # Refresh date_updated if requested + if refresh_date_updated: + await refresh_thread_date_updated( + thread_id=thread_id, user=user, session=session + ) + return thread_id else: # Generate new thread_id if the user did not provide a thread_id @@ -148,6 +156,10 @@ async def update_thread_name( ) -> None: """Update the name of an existing chat thread. + This function performs security checks by including both threadId and userId + in the WHERE clause of the UPDATE query, ensuring users can only update + threads that belong to them. No separate permission check is needed. + Args: thread_id: The unique identifier for the chat thread. user: The current user. @@ -155,32 +167,88 @@ async def update_thread_name( session: The database session for updating. Raises: - PermissionError: If the thread does not belong to the user. - ValueError: If the thread does not exist. + ValueError: If the thread does not exist or does not belong to the user. """ - # Verify thread exists and belongs to user - await assure_thread_exists_and_belongs_to_user( - thread_id=thread_id, - user=user, - session=session, - ) - logger.info( f"Updating thread {thread_id} name to '{new_thread_name}' for user {user.id}" ) - # Update the thread name + # Update the thread name and date_updated + # Security check: WHERE clause includes both threadId AND userId stmt = ( update(UserThreadMapping) - .where(UserThreadMapping.threadId == thread_id) - .values(threadName=new_thread_name) + .where( + UserThreadMapping.threadId == thread_id, + UserThreadMapping.userId == user.id, + ) + .values(threadName=new_thread_name, date_updated=datetime.now(timezone.utc)) ) - await session.execute(stmt) + result = await session.execute(stmt) await session.commit() + # Check if any rows were affected + if result.rowcount == 0: + logger.warning( + f"Failed to update thread {thread_id} for user {user.id} - " + "thread does not exist or does not belong to user" + ) + raise ValueError( + f"Thread {thread_id} does not exist or you do not have permission to access it" + ) + logger.info(f"Successfully updated thread {thread_id} name for user {user.id}") +async def refresh_thread_date_updated( + *, + thread_id: str, + user: User, + session: AsyncSession, +) -> None: + """Refresh the date_updated timestamp for an existing chat thread. + + This function performs security checks by including both threadId and userId + in the WHERE clause of the UPDATE query, ensuring users can only update + threads that belong to them. No separate permission check is needed. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + session: The database session for updating. + + Raises: + ValueError: If the thread does not exist or does not belong to the user. + """ + logger.info(f"Refreshing date_updated for thread {thread_id} for user {user.id}") + + # Update the date_updated timestamp + # Security check: WHERE clause includes both threadId AND userId + stmt = ( + update(UserThreadMapping) + .where( + UserThreadMapping.threadId == thread_id, + UserThreadMapping.userId == user.id, + ) + .values(date_updated=datetime.now(timezone.utc)) + ) + result = await session.execute(stmt) + await session.commit() + + # Check if any rows were affected + if result.rowcount == 0: + logger.warning( + f"Failed to refresh thread {thread_id} for user {user.id} - " + "thread does not exist or does not belong to user" + ) + raise ValueError( + f"Thread {thread_id} does not exist or you do not have permission to access it" + ) + + logger.info( + f"Successfully refreshed date_updated for thread {thread_id} for user {user.id}" + ) + + async def post_message( *, thread_id: str, diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index 8f3efd65..e14f38c6 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -9,7 +9,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from modules.features.chatBot.database import get_async_db_session -from modules.features.chatBot.service import get_or_create_thread_for_user +from modules.features.chatBot.service import ( + get_or_create_thread_for_user, +) from modules.datamodels.datamodelUam import User from modules.datamodels.datamodelChatbot import ( ChatMessageRequest, @@ -56,6 +58,8 @@ async def post_chat_message_stream( thread_id=message_request.thread_id, user=currentUser, session=session, + thread_name=message_request.message[:100], + refresh_date_updated=True, ) logger.info( @@ -106,6 +110,8 @@ async def post_chat_message( thread_id=message_request.thread_id, user=currentUser, session=session, + thread_name=message_request.message[:100], + refresh_date_updated=True, ) logger.info(f"User {currentUser.id} posted message to thread {thread_id}") From a08bd3ef1db4b9f58b57477b34815c8b56ce7ea0 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 15:45:41 +0200 Subject: [PATCH 17/29] feat: implement get threads endpoint --- modules/datamodels/datamodelChatbot.py | 12 +++---- modules/features/chatBot/service.py | 47 +++++++++++++++++++++++++- modules/routes/routeChatbot.py | 37 ++++++-------------- 3 files changed, 62 insertions(+), 34 deletions(-) diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py index 906757b7..9b6c723a 100644 --- a/modules/datamodels/datamodelChatbot.py +++ b/modules/datamodels/datamodelChatbot.py @@ -34,9 +34,9 @@ class ThreadSummary(BaseModel, ModelMixin): """Summary of a chat thread for list view""" thread_id: str = Field(..., description="Thread ID") - created_at: float = Field(..., description="Thread creation timestamp") - last_message: str = Field(..., description="Last message content") - message_count: int = Field(..., description="Total number of messages") + thread_name: str = Field(..., description="Thread name") + date_created: float = Field(..., description="Thread creation timestamp") + date_updated: float = Field(..., description="Thread last updated timestamp") class ThreadListResponse(BaseModel, ModelMixin): @@ -96,9 +96,9 @@ register_model_labels( {"en": "Thread Summary", "fr": "Résumé du fil"}, { "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, - "created_at": {"en": "Created At", "fr": "Créé le"}, - "last_message": {"en": "Last Message", "fr": "Dernier message"}, - "message_count": {"en": "Message Count", "fr": "Nombre de messages"}, + "thread_name": {"en": "Thread Name", "fr": "Nom du fil"}, + "date_created": {"en": "Date Created", "fr": "Date de création"}, + "date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"}, }, ) diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index 05df7f0f..3060081d 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -13,7 +13,11 @@ from modules.features.chatBot.utils.checkpointer import get_checkpointer from modules.features.chatBot.utils.toolRegistry import get_registry from modules.features.chatBot.utils import permissions from modules.features.chatBot.database import UserThreadMapping -from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse +from modules.datamodels.datamodelChatbot import ( + MessageItem, + ChatMessageResponse, + ThreadSummary, +) from modules.datamodels.datamodelUam import User from langchain_core.messages import HumanMessage, AIMessage @@ -22,6 +26,47 @@ from modules.shared.configuration import APP_CONFIG logger = logging.getLogger(__name__) +async def get_all_threads_for_user( + *, + user: User, + session: AsyncSession, +) -> List[ThreadSummary]: + """Get all chat threads for a user. + + Args: + user: The current user. + session: The database session for querying. + + Returns: + List of ThreadSummary objects sorted by date_updated (newest first). + Returns empty list if no threads found. + """ + logger.info(f"Fetching all threads for user {user.id}") + + # Query all threads for this user, ordered by date_updated descending + stmt = ( + select(UserThreadMapping) + .where(UserThreadMapping.userId == user.id) + .order_by(UserThreadMapping.date_updated.desc()) + ) + result = await session.execute(stmt) + thread_mappings = result.scalars().all() + + # Convert to ThreadSummary objects + threads = [] + for mapping in thread_mappings: + thread_summary = ThreadSummary( + thread_id=mapping.threadId, + thread_name=mapping.threadName, + date_created=mapping.date_created.timestamp(), + date_updated=mapping.date_updated.timestamp(), + ) + threads.append(thread_summary) + + logger.info(f"Found {len(threads)} threads for user {user.id}") + return threads + + async def save_thread_for_user( *, thread_id: str, diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index e14f38c6..b7faec2d 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -143,40 +143,23 @@ async def post_chat_message( @router.get("/threads", response_model=ThreadListResponse) @limiter.limit("30/minute") async def get_all_threads( - *, request: Request, currentUser: User = Depends(getCurrentUser) + *, + request: Request, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), ) -> ThreadListResponse: """ Get all chat threads for the current user. - - This endpoint will later fetch from LangGraph's PostgreSQL checkpointer. """ try: - # Return dummy thread data - # In production, this will query LangGraph's checkpointer database - dummy_threads = [ - ThreadSummary( - thread_id="thread_001", - created_at=datetime.now().timestamp() - 86400, # 1 day ago - last_message="Hello, how can I help you?", - message_count=4, - ), - ThreadSummary( - thread_id="thread_002", - created_at=datetime.now().timestamp() - 3600, # 1 hour ago - last_message="Thank you for your help!", - message_count=8, - ), - ThreadSummary( - thread_id="thread_003", - created_at=datetime.now().timestamp() - 300, # 5 minutes ago - last_message="Can you explain this concept?", - message_count=2, - ), - ] + # Get all threads for the current user + threads = await chat_service.get_all_threads_for_user( + user=currentUser, session=session + ) - logger.info(f"User {currentUser.id} retrieved {len(dummy_threads)} threads") + logger.info(f"User {currentUser.id} retrieved {len(threads)} threads") - return ThreadListResponse(threads=dummy_threads) + return ThreadListResponse(threads=threads) except Exception as e: logger.error( From ed3920f9f91d3143ce3684cdb5e6117663d92e29 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 15:53:15 +0200 Subject: [PATCH 18/29] chore: update code styles --- modules/features/chatBot/database.py | 8 +++---- modules/features/chatBot/service.py | 32 ++++++++++++++-------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/modules/features/chatBot/database.py b/modules/features/chatBot/database.py index 50dc9bba..5473fa0f 100644 --- a/modules/features/chatBot/database.py +++ b/modules/features/chatBot/database.py @@ -24,11 +24,11 @@ class UserThreadMapping(Base): Thread_id is unique in the table. """ - __tablename__ = "userThreads" + __tablename__ = "user_threads" id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) - userId: Mapped[str] = mapped_column(String(255), nullable=False) - threadId: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) - threadName: Mapped[str] = mapped_column(String(255), nullable=False) + user_id: Mapped[str] = mapped_column(String(255), nullable=False) + thread_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + thread_name: Mapped[str] = mapped_column(String(255), nullable=False) date_created: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=False, diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index 3060081d..ef3252e9 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -46,7 +46,7 @@ async def get_all_threads_for_user( # Query all threads for this user, ordered by date_updated descending stmt = ( select(UserThreadMapping) - .where(UserThreadMapping.userId == user.id) + .where(UserThreadMapping.user_id == user.id) .order_by(UserThreadMapping.date_updated.desc()) ) result = await session.execute(stmt) @@ -56,8 +56,8 @@ async def get_all_threads_for_user( threads = [] for mapping in thread_mappings: thread_summary = ThreadSummary( - thread_id=mapping.threadId, - thread_name=mapping.threadName, + thread_id=mapping.thread_id, + thread_name=mapping.thread_name, date_created=mapping.date_created.timestamp(), date_updated=mapping.date_updated.timestamp(), ) @@ -86,9 +86,9 @@ async def save_thread_for_user( # Create new mapping entry new_mapping = UserThreadMapping( - userId=user.id, - threadId=thread_id, - threadName=thread_name, + user_id=user.id, + thread_id=thread_id, + thread_name=thread_name, ) session.add(new_mapping) @@ -170,7 +170,7 @@ async def assure_thread_exists_and_belongs_to_user( ValueError: If the thread does not exist. """ # Query the database for the thread mapping - stmt = select(UserThreadMapping).where(UserThreadMapping.threadId == thread_id) + stmt = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id) result = await session.execute(stmt) thread_mapping = result.scalar_one_or_none() @@ -180,10 +180,10 @@ async def assure_thread_exists_and_belongs_to_user( raise ValueError(f"Thread {thread_id} does not exist") # Check if thread belongs to the user - if thread_mapping.userId != user.id: + if thread_mapping.user_id != user.id: logger.warning( f"User {user.id} attempted to access thread {thread_id} " - f"belonging to user {thread_mapping.userId}" + f"belonging to user {thread_mapping.user_id}" ) raise PermissionError( f"You do not have permission to access thread {thread_id}" @@ -219,14 +219,14 @@ async def update_thread_name( ) # Update the thread name and date_updated - # Security check: WHERE clause includes both threadId AND userId + # Security check: WHERE clause includes both thread_id AND user_id stmt = ( update(UserThreadMapping) .where( - UserThreadMapping.threadId == thread_id, - UserThreadMapping.userId == user.id, + UserThreadMapping.thread_id == thread_id, + UserThreadMapping.user_id == user.id, ) - .values(threadName=new_thread_name, date_updated=datetime.now(timezone.utc)) + .values(thread_name=new_thread_name, date_updated=datetime.now(timezone.utc)) ) result = await session.execute(stmt) await session.commit() @@ -267,12 +267,12 @@ async def refresh_thread_date_updated( logger.info(f"Refreshing date_updated for thread {thread_id} for user {user.id}") # Update the date_updated timestamp - # Security check: WHERE clause includes both threadId AND userId + # Security check: WHERE clause includes both thread_id AND user_id stmt = ( update(UserThreadMapping) .where( - UserThreadMapping.threadId == thread_id, - UserThreadMapping.userId == user.id, + UserThreadMapping.thread_id == thread_id, + UserThreadMapping.user_id == user.id, ) .values(date_updated=datetime.now(timezone.utc)) ) From 85503fc66904984a45898960fd2d9268d78e7735 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 16:14:17 +0200 Subject: [PATCH 19/29] feat: implement get thread details endpoint --- modules/datamodels/datamodelChatbot.py | 6 +- modules/features/chatBot/service.py | 126 ++++++++++++++++++++++++- modules/routes/routeChatbot.py | 60 +++++------- 3 files changed, 154 insertions(+), 38 deletions(-) diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py index 9b6c723a..243eb02c 100644 --- a/modules/datamodels/datamodelChatbot.py +++ b/modules/datamodels/datamodelChatbot.py @@ -49,7 +49,8 @@ class ThreadDetail(BaseModel, ModelMixin): """Detailed view of a single thread""" thread_id: str = Field(..., description="Thread ID") - created_at: float = Field(..., description="Thread creation timestamp") + date_created: float = Field(..., description="Thread creation timestamp") + date_updated: float = Field(..., description="Thread last updated timestamp") messages: List[MessageItem] = Field( ..., description="All messages in chronological order" ) @@ -115,7 +116,8 @@ register_model_labels( {"en": "Thread Detail", "fr": "Détail du fil"}, { "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, - "created_at": {"en": "Created At", "fr": "Créé le"}, + "date_created": {"en": "Date Created", "fr": "Date de création"}, + "date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"}, "messages": {"en": "Messages", "fr": "Messages"}, }, ) diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index ef3252e9..8fb35106 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -17,10 +17,11 @@ from modules.datamodels.datamodelChatbot import ( MessageItem, ChatMessageResponse, ThreadSummary, + ThreadDetail, ) from modules.datamodels.datamodelUam import User -from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.messages import HumanMessage, AIMessage, BaseMessage from modules.shared.configuration import APP_CONFIG logger = logging.getLogger(__name__) @@ -491,3 +492,126 @@ async def post_message_stream( ) + "\n\n" ) + + +async def get_thread_messages_from_langgraph( + *, + thread_id: str, + app, +) -> List[dict]: + """Retrieve and format messages from LangGraph checkpointer. + + Args: + thread_id: The unique identifier for the chat thread. + app: The compiled LangGraph app with checkpointer. + + Returns: + List of message dicts with role, content, and timestamp. + """ + ROLE_MAP = {"human": "user", "ai": "assistant"} + + cfg = {"configurable": {"thread_id": thread_id}} + state = await app.aget_state(cfg) + + messages = [] + for msg in state.values.get("messages", []): + # Skip system and tool messages - only include user and assistant + if msg.type not in ["human", "ai"]: + continue + + # Convert content to string if needed + content = msg.content if isinstance(msg.content, str) else str(msg.content) + + messages.append( + { + "role": ROLE_MAP.get(msg.type, msg.type), + "content": content, + "timestamp": 0.0, + } + ) + + return messages + + +async def get_thread_detail_for_user( + *, + thread_id: str, + user: User, + session: AsyncSession, +) -> ThreadDetail: + """Get detailed thread information with message history from LangGraph. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + session: The database session for querying. + + Returns: + ThreadDetail object with thread metadata and message history. + + Raises: + PermissionError: If the thread does not belong to the user. + ValueError: If the thread does not exist. + """ + logger.info(f"Getting thread detail for thread {thread_id} for user {user.id}") + + # Verify thread exists and belongs to user + await assure_thread_exists_and_belongs_to_user( + thread_id=thread_id, user=user, session=session + ) + + # Get thread metadata from database + stmt = select(UserThreadMapping).where(UserThreadMapping.thread_id == thread_id) + result = await session.execute(stmt) + thread_mapping = result.scalar_one() + + # Build the chatbot app to access LangGraph state + # Use same approach as post_message for consistency + tool_ids = permissions.get_chatbot_tools(user_id=user.id) + if not tool_ids: + raise ValueError("User does not have permission to use any chatbot tools") + + model_name = permissions.get_chatbot_model(user_id=user.id) + system_prompt = permissions.get_system_prompt(user_id=user.id) + + # Get tools from registry + registry = get_registry() + tools = registry.get_tool_instances(tool_ids=tool_ids) + + # Get model and checkpointer + model = get_langchain_model(model_name=model_name) + checkpointer = get_checkpointer() + + # Get context window size from config + context_window_size = int( + APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) + ) + + # Create chatbot instance + chatbot = await Chatbot.create( + model=model, + memory=checkpointer, + system_prompt=system_prompt, + tools=tools, + context_window_size=context_window_size, + ) + + # Get messages from LangGraph checkpointer + message_dicts = await get_thread_messages_from_langgraph( + thread_id=thread_id, app=chatbot.app + ) + + # Convert to MessageItem objects + messages = [MessageItem(**m) for m in message_dicts] + + logger.info( + f"Retrieved thread {thread_id} with {len(messages)} messages for user {user.id}" + ) + + # Return ThreadDetail + return ThreadDetail( + thread_id=thread_id, + date_created=thread_mapping.date_created.timestamp(), + date_updated=thread_mapping.date_updated.timestamp(), + messages=messages, + ) diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index b7faec2d..cfc8b620 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -174,47 +174,37 @@ async def get_all_threads( @router.get("/threads/{thread_id}", response_model=ThreadDetail) @limiter.limit("30/minute") async def get_thread_by_id( - *, request: Request, thread_id: str, currentUser: User = Depends(getCurrentUser) + *, + request: Request, + thread_id: str, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), ) -> ThreadDetail: """ - Get a specific chat thread with all its messages. - - This endpoint will later fetch from LangGraph's PostgreSQL checkpointer. + Get a specific chat thread with all its messages from LangGraph checkpointer. """ try: - # Return dummy thread detail - # In production, this will query LangGraph's checkpointer for the specific thread - current_time = datetime.now().timestamp() - - dummy_messages = [ - MessageItem( - role="user", - content="Hello! I need help with Python.", - timestamp=current_time - 120, - ), - MessageItem( - role="assistant", - content="Hello! I'd be happy to help you with Python. What would you like to know?", - timestamp=current_time - 119, - ), - MessageItem( - role="user", - content="How do I use list comprehensions?", - timestamp=current_time - 60, - ), - MessageItem( - role="assistant", - content="List comprehensions are a concise way to create lists. Here's an example: [x**2 for x in range(10)]", - timestamp=current_time - 59, - ), - ] - - logger.info(f"User {currentUser.id} retrieved thread {thread_id}") - - return ThreadDetail( - thread_id=thread_id, created_at=current_time - 120, messages=dummy_messages + thread_detail = await chat_service.get_thread_detail_for_user( + thread_id=thread_id, + user=currentUser, + session=session, ) + logger.info(f"User {currentUser.id} retrieved thread {thread_id}") + return thread_detail + + except ValueError as e: + logger.error(f"Thread not found: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e) or "Thread not found", + ) + except PermissionError as e: + logger.error(f"Permission denied: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) or "Permission denied", + ) except Exception as e: logger.error( f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}", From a0a87e2e3e6c43e02d83786d2f161fb6792718c9 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 16:30:44 +0200 Subject: [PATCH 20/29] fix: remove timestamps --- modules/datamodels/datamodelChatbot.py | 2 -- modules/features/chatBot/service.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py index 243eb02c..04ac0788 100644 --- a/modules/datamodels/datamodelChatbot.py +++ b/modules/datamodels/datamodelChatbot.py @@ -11,7 +11,6 @@ class MessageItem(BaseModel, ModelMixin): role: str = Field(..., description="Message role (user or assistant)") content: str = Field(..., description="Message content") - timestamp: float = Field(..., description="Message timestamp (Unix timestamp)") class ChatMessageRequest(BaseModel, ModelMixin): @@ -70,7 +69,6 @@ register_model_labels( { "role": {"en": "Role", "fr": "Rôle"}, "content": {"en": "Content", "fr": "Contenu"}, - "timestamp": {"en": "Timestamp", "fr": "Horodatage"}, }, ) diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index 8fb35106..e9a9b360 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -365,7 +365,6 @@ async def post_message( item = MessageItem( role=role, content=msg.content.strip(), - timestamp=0.0, # TODO: Add proper timestamp handling ) messages.append(item) @@ -457,7 +456,6 @@ async def post_message_stream( MessageItem( role=role, content=content, - timestamp=0.0, # TODO: Add proper timestamp handling ) ) @@ -526,7 +524,6 @@ async def get_thread_messages_from_langgraph( { "role": ROLE_MAP.get(msg.type, msg.type), "content": content, - "timestamp": 0.0, } ) From bfc07ee0b1902feb1fedf5449aeba2e7d327055c Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 16:50:54 +0200 Subject: [PATCH 21/29] feat: implement rename, delete thread endpoints --- modules/datamodels/datamodelChatbot.py | 14 ++++ modules/features/chatBot/service.py | 90 +++++++++++++++++++++++++- modules/routes/routeChatbot.py | 81 ++++++++++++++++++++--- 3 files changed, 175 insertions(+), 10 deletions(-) diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py index 04ac0788..0ced3503 100644 --- a/modules/datamodels/datamodelChatbot.py +++ b/modules/datamodels/datamodelChatbot.py @@ -55,6 +55,12 @@ class ThreadDetail(BaseModel, ModelMixin): ) +class RenameThreadRequest(BaseModel, ModelMixin): + """Request model for renaming a thread""" + + new_name: str = Field(..., description="New name for the thread") + + class DeleteResponse(BaseModel, ModelMixin): """Response model for delete operations""" @@ -120,6 +126,14 @@ register_model_labels( }, ) +register_model_labels( + "RenameThreadRequest", + {"en": "Rename Thread Request", "fr": "Demande de renommage de fil"}, + { + "new_name": {"en": "New Name", "fr": "Nouveau nom"}, + }, +) + register_model_labels( "DeleteResponse", {"en": "Delete Response", "fr": "Réponse de suppression"}, diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index e9a9b360..bcf11abf 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -5,7 +5,7 @@ import logging from datetime import datetime, timezone from typing import AsyncIterator, List, Optional -from sqlalchemy import select, update +from sqlalchemy import select, update, delete from sqlalchemy.ext.asyncio import AsyncSession from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model @@ -612,3 +612,91 @@ async def get_thread_detail_for_user( date_updated=thread_mapping.date_updated.timestamp(), messages=messages, ) + + +async def delete_thread_for_user( + *, + thread_id: str, + user: User, + session: AsyncSession, +) -> None: + """Delete a chat thread for a user from both LangGraph and the database. + + Args: + thread_id: The unique identifier for the chat thread. + user: The current user. + session: The database session for deleting. + + Raises: + PermissionError: If the thread does not belong to the user. + ValueError: If the thread does not exist. + """ + logger.info(f"Deleting thread {thread_id} for user {user.id}") + + # Verify thread exists and belongs to user + await assure_thread_exists_and_belongs_to_user( + thread_id=thread_id, user=user, session=session + ) + + # Build the chatbot app to access the checkpointer + tool_ids = permissions.get_chatbot_tools(user_id=user.id) + if not tool_ids: + raise ValueError("User does not have permission to use any chatbot tools") + + model_name = permissions.get_chatbot_model(user_id=user.id) + system_prompt = permissions.get_system_prompt(user_id=user.id) + + # Get tools from registry + registry = get_registry() + tools = registry.get_tool_instances(tool_ids=tool_ids) + + # Get model and checkpointer + model = get_langchain_model(model_name=model_name) + checkpointer = get_checkpointer() + + # Get context window size from config + context_window_size = int( + APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) + ) + + # Create chatbot instance + chatbot = await Chatbot.create( + model=model, + memory=checkpointer, + system_prompt=system_prompt, + tools=tools, + context_window_size=context_window_size, + ) + + # Delete from LangGraph checkpointer + try: + await chatbot.app.checkpointer.adelete_thread(thread_id) + logger.info(f"Deleted thread {thread_id} from LangGraph checkpointer") + except Exception as e: + logger.error( + f"Failed to delete thread {thread_id} from LangGraph: {type(e).__name__}: {str(e)}", + exc_info=True, + ) + raise ValueError( + f"Failed to delete thread from LangGraph: {type(e).__name__}: {str(e)}" + ) + + # Delete from database + stmt = delete(UserThreadMapping).where( + UserThreadMapping.thread_id == thread_id, + UserThreadMapping.user_id == user.id, + ) + result = await session.execute(stmt) + await session.commit() + + # Check if any rows were deleted + if result.rowcount == 0: + logger.warning( + f"Failed to delete thread {thread_id} from database for user {user.id} - " + "thread does not exist or does not belong to user" + ) + raise ValueError( + f"Thread {thread_id} does not exist or you do not have permission to access it" + ) + + logger.info(f"Successfully deleted thread {thread_id} for user {user.id}") diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index cfc8b620..1c2b5ed6 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -20,6 +20,7 @@ from modules.datamodels.datamodelChatbot import ( ThreadSummary, ThreadListResponse, ThreadDetail, + RenameThreadRequest, DeleteResponse, ) from modules.security.auth import getCurrentUser, limiter @@ -216,29 +217,91 @@ async def get_thread_by_id( ) +@router.patch("/threads/{thread_id}", response_model=DeleteResponse) +@limiter.limit("30/minute") +async def rename_thread( + *, + request: Request, + thread_id: str, + rename_request: RenameThreadRequest, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), +) -> DeleteResponse: + """ + Rename a chat thread. + """ + try: + await chat_service.update_thread_name( + thread_id=thread_id, + user=currentUser, + new_thread_name=rename_request.new_name, + session=session, + ) + + logger.info( + f"User {currentUser.id} renamed thread {thread_id} to '{rename_request.new_name}'" + ) + + return DeleteResponse( + message=f"Thread {thread_id} successfully renamed", + thread_id=thread_id, + ) + + except ValueError as e: + logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e) or "Thread not found or permission denied", + ) + except Exception as e: + logger.error( + f"Error renaming thread {thread_id}: {type(e).__name__}: {str(e)}", + exc_info=True, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to rename thread: {type(e).__name__}: {str(e) or 'No error message provided'}", + ) + + @router.delete("/threads/{thread_id}", response_model=DeleteResponse) @limiter.limit("10/minute") async def delete_thread( - *, request: Request, thread_id: str, currentUser: User = Depends(getCurrentUser) + *, + request: Request, + thread_id: str, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), ) -> DeleteResponse: """ - Delete a chat thread and all its associated data. - - This endpoint will later delete from LangGraph's PostgreSQL checkpointer. + Delete a chat thread and all its associated data from both LangGraph and database. """ try: - # In production, this will: - # 1. Verify the thread belongs to the current user - # 2. Delete the thread from LangGraph's checkpointer - # 3. Clean up any associated data + await chat_service.delete_thread_for_user( + thread_id=thread_id, + user=currentUser, + session=session, + ) logger.info(f"User {currentUser.id} deleted thread {thread_id}") return DeleteResponse( - message=f"Thread {thread_id} successfully deleted (dummy response)", + message=f"Thread {thread_id} successfully deleted", thread_id=thread_id, ) + except ValueError as e: + logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e) or "Thread not found or permission denied", + ) + except PermissionError as e: + logger.error(f"Permission denied: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) or "Permission denied", + ) except Exception as e: logger.error( f"Error deleting thread {thread_id}: {type(e).__name__}: {str(e)}", From 9bf454823b12d6eae2d96b8d321ebd0d01bf1c13 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 17:14:58 +0200 Subject: [PATCH 22/29] chore: use minimal langgraph app to read and delete checkpointer history --- modules/features/chatBot/service.py | 120 ++++++++++++---------------- 1 file changed, 51 insertions(+), 69 deletions(-) diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index bcf11abf..9cbb3ccf 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -22,6 +22,7 @@ from modules.datamodels.datamodelChatbot import ( from modules.datamodels.datamodelUam import User from langchain_core.messages import HumanMessage, AIMessage, BaseMessage +from langgraph.graph import StateGraph, MessagesState, START, END from modules.shared.configuration import APP_CONFIG logger = logging.getLogger(__name__) @@ -492,22 +493,65 @@ async def post_message_stream( ) +# Module-level singleton for minimal app used to read thread state +_MINIMAL_APP = None + + +def _build_minimal_app(*, checkpointer): + """Build a minimal LangGraph app for reading thread state. + + This creates a valid graph with a no-op node that we never actually run. + LangGraph requires a valid graph structure (with edges from START) to compile, + even though we only use it to call aget_state() to read from the checkpointer. + + Args: + checkpointer: The checkpointer to attach to the graph. + + Returns: + A compiled StateGraph that can be used to read thread state. + """ + graph = StateGraph(MessagesState) + + # No-op node that returns the state unchanged + def noop(state: dict) -> dict: + return state + + graph.add_node("noop", noop) + graph.add_edge(START, "noop") + graph.add_edge("noop", END) + + return graph.compile(checkpointer=checkpointer) + + +def _get_minimal_app(): + """Get the module-level singleton minimal app. + + Returns: + The cached minimal app, building it on first access. + """ + global _MINIMAL_APP + if _MINIMAL_APP is None: + _MINIMAL_APP = _build_minimal_app(checkpointer=get_checkpointer()) + return _MINIMAL_APP + + async def get_thread_messages_from_langgraph( *, thread_id: str, - app, ) -> List[dict]: """Retrieve and format messages from LangGraph checkpointer. Args: thread_id: The unique identifier for the chat thread. - app: The compiled LangGraph app with checkpointer. Returns: - List of message dicts with role, content, and timestamp. + List of message dicts with role and content. """ ROLE_MAP = {"human": "user", "ai": "assistant"} + # Get the minimal app (singleton, built once) + app = _get_minimal_app() + cfg = {"configurable": {"thread_id": thread_id}} state = await app.aget_state(cfg) @@ -562,41 +606,8 @@ async def get_thread_detail_for_user( result = await session.execute(stmt) thread_mapping = result.scalar_one() - # Build the chatbot app to access LangGraph state - # Use same approach as post_message for consistency - tool_ids = permissions.get_chatbot_tools(user_id=user.id) - if not tool_ids: - raise ValueError("User does not have permission to use any chatbot tools") - - model_name = permissions.get_chatbot_model(user_id=user.id) - system_prompt = permissions.get_system_prompt(user_id=user.id) - - # Get tools from registry - registry = get_registry() - tools = registry.get_tool_instances(tool_ids=tool_ids) - - # Get model and checkpointer - model = get_langchain_model(model_name=model_name) - checkpointer = get_checkpointer() - - # Get context window size from config - context_window_size = int( - APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) - ) - - # Create chatbot instance - chatbot = await Chatbot.create( - model=model, - memory=checkpointer, - system_prompt=system_prompt, - tools=tools, - context_window_size=context_window_size, - ) - - # Get messages from LangGraph checkpointer - message_dicts = await get_thread_messages_from_langgraph( - thread_id=thread_id, app=chatbot.app - ) + # Get messages from LangGraph checkpointer (optimized - no full chatbot needed) + message_dicts = await get_thread_messages_from_langgraph(thread_id=thread_id) # Convert to MessageItem objects messages = [MessageItem(**m) for m in message_dicts] @@ -638,39 +649,10 @@ async def delete_thread_for_user( thread_id=thread_id, user=user, session=session ) - # Build the chatbot app to access the checkpointer - tool_ids = permissions.get_chatbot_tools(user_id=user.id) - if not tool_ids: - raise ValueError("User does not have permission to use any chatbot tools") - - model_name = permissions.get_chatbot_model(user_id=user.id) - system_prompt = permissions.get_system_prompt(user_id=user.id) - - # Get tools from registry - registry = get_registry() - tools = registry.get_tool_instances(tool_ids=tool_ids) - - # Get model and checkpointer - model = get_langchain_model(model_name=model_name) + # Delete from LangGraph checkpointer (optimized - no app/tools/model needed) checkpointer = get_checkpointer() - - # Get context window size from config - context_window_size = int( - APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) - ) - - # Create chatbot instance - chatbot = await Chatbot.create( - model=model, - memory=checkpointer, - system_prompt=system_prompt, - tools=tools, - context_window_size=context_window_size, - ) - - # Delete from LangGraph checkpointer try: - await chatbot.app.checkpointer.adelete_thread(thread_id) + await checkpointer.adelete_thread(thread_id) logger.info(f"Deleted thread {thread_id} from LangGraph checkpointer") except Exception as e: logger.error( From 179b848ecbd98cc5e68892535b6a7076a02cd03e Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Wed, 8 Oct 2025 17:22:07 +0200 Subject: [PATCH 23/29] fix: hide tool calls from chatbot history --- modules/features/chatBot/service.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index 9cbb3ccf..510018c3 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -561,13 +561,14 @@ async def get_thread_messages_from_langgraph( if msg.type not in ["human", "ai"]: continue - # Convert content to string if needed - content = msg.content if isinstance(msg.content, str) else str(msg.content) + # Skip messages with non-string content (e.g., tool calls) + if not isinstance(msg.content, str): + continue messages.append( { "role": ROLE_MAP.get(msg.type, msg.type), - "content": content, + "content": msg.content, } ) From 8538821d0c08389ecdcdc153df41fa20809e1d90 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Thu, 9 Oct 2025 10:54:11 +0200 Subject: [PATCH 24/29] fix: pydantic schema issue --- modules/datamodels/datamodelChat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index 31e65004..198360bf 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -1,6 +1,7 @@ """Chat models: ChatWorkflow, ChatMessage, ChatLog, ChatStat, ChatDocument.""" from typing import List, Dict, Any, Optional +from enum import Enum from pydantic import BaseModel, Field from modules.shared.attributeUtils import register_model_labels, ModelMixin from modules.shared.timezoneUtils import get_utc_timestamp @@ -574,7 +575,7 @@ register_model_labels( ) -class TaskStatus(str): +class TaskStatus(str, Enum): PENDING = "pending" RUNNING = "running" COMPLETED = "completed" From 0c5d0f957fd6174ee2c673fdaf5c006af5c595d2 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Thu, 9 Oct 2025 11:39:04 +0200 Subject: [PATCH 25/29] feat: add tools and tools permissions to db with auto sync --- app.py | 8 ++ modules/features/chatBot/database.py | 141 ++++++++++++++++++++++++++- 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/app.py b/app.py index b75b59e8..384d6bfa 100644 --- a/app.py +++ b/app.py @@ -275,6 +275,14 @@ async def lifespan(app: FastAPI): # NOTE: Might need Alembic migrations in the future await init_chatbot_models(engine) + # --- Sync tools from registry to database --- + from modules.features.chatBot.database import sync_tools_from_registry + + async with SessionLocal() as session: + await sync_tools_from_registry(session) + await session.commit() + logger.info("Tools synced from registry to database") + # --- Initialize LangGraph checkpointer --- from modules.features.chatBot.utils.checkpointer import ( diff --git a/modules/features/chatBot/database.py b/modules/features/chatBot/database.py index 5473fa0f..1dc4ebe6 100644 --- a/modules/features/chatBot/database.py +++ b/modules/features/chatBot/database.py @@ -5,13 +5,71 @@ from datetime import datetime, timezone from fastapi import Request from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy import String, Uuid, DateTime +from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint class Base(DeclarativeBase): pass +# Tools Table +class Tool(Base): + """Available chatbot tools. + + Stores information about all available tools that can be assigned to users. + Each tool has a unique tool_id that corresponds to the registry tool_id. + """ + + __tablename__ = "tools" + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + label: Mapped[str] = mapped_column(String(255), nullable=False) + category: Mapped[str] = mapped_column(String(50), nullable=False) + description: Mapped[str] = mapped_column(String(1000), nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + date_created: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + date_updated: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + + +# User-Tool Mapping Table +class UserToolMapping(Base): + """Mapping of users to their available tools. + + Many-to-many relationship between users and tools. + - One user can have multiple tools + - One tool can be assigned to multiple users + + The combination of user_id and tool_id is unique. + """ + + __tablename__ = "user_tools" + __table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),) + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + user_id: Mapped[str] = mapped_column(String(255), nullable=False) + tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + date_granted: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + date_updated: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + + # User Thread Mapping Table class UserThreadMapping(Base): """Mapping of users to their chat threads. @@ -56,3 +114,84 @@ async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]: async def init_models(engine) -> None: async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) + + +async def sync_tools_from_registry(session: AsyncSession) -> None: + """Sync tools from tool registry to database. + + This function: + - Adds new tools from the registry to the database + - Updates existing tools with current registry information + - Marks tools not present in the registry as inactive + + Should be called on application startup after database initialization. + + Args: + session: Active database session + """ + import logging + from sqlalchemy import select + + from modules.features.chatBot.utils.toolRegistry import get_registry + + logger = logging.getLogger(__name__) + logger.info("Syncing tools from registry to database...") + + # Get all tools from the registry + registry = get_registry() + registry_tools = registry.get_all_tools() + + # Create a set of tool_ids from the registry + registry_tool_ids = {tool.tool_id for tool in registry_tools} + + logger.info(f"Found {len(registry_tools)} tools in registry") + + # Get all existing tools from the database + result = await session.execute(select(Tool)) + db_tools = result.scalars().all() + db_tools_by_tool_id = {tool.tool_id: tool for tool in db_tools} + + logger.info(f"Found {len(db_tools)} tools in database") + + # Track changes + added_count = 0 + updated_count = 0 + deactivated_count = 0 + + # Sync tools from registry to database + for registry_tool in registry_tools: + if registry_tool.tool_id in db_tools_by_tool_id: + # Tool exists - update it + # Preserve label and description (user-editable fields) + db_tool = db_tools_by_tool_id[registry_tool.tool_id] + db_tool.name = registry_tool.name + db_tool.category = registry_tool.category + db_tool.is_active = True + db_tool.date_updated = datetime.now(timezone.utc) + updated_count += 1 + logger.debug(f"Updated tool: {registry_tool.tool_id}") + else: + # Tool doesn't exist - create it + new_tool = Tool( + tool_id=registry_tool.tool_id, + name=registry_tool.name, + label=registry_tool.tool_id, # Use tool_id as label per spec + category=registry_tool.category, + description=registry_tool.description or "", + is_active=True, + ) + session.add(new_tool) + added_count += 1 + logger.debug(f"Added new tool: {registry_tool.tool_id}") + + # Mark tools not in registry as inactive + for db_tool in db_tools: + if db_tool.tool_id not in registry_tool_ids and db_tool.is_active: + db_tool.is_active = False + db_tool.date_updated = datetime.now(timezone.utc) + deactivated_count += 1 + logger.debug(f"Deactivated tool not in registry: {db_tool.tool_id}") + + logger.info( + f"Tool sync complete: {added_count} added, {updated_count} updated, {deactivated_count} deactivated" + ) From 4b75ffd70cd9ce1ae9aba4cafa10f1133157a171 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Thu, 9 Oct 2025 15:30:15 +0200 Subject: [PATCH 26/29] feat: add admin endpoints for chatbot tools --- modules/datamodels/datamodelChatbot.py | 68 ++++++++ modules/features/chatBot/service.py | 207 +++++++++++++++++++++++ modules/routes/routeChatbot.py | 225 ++++++++++++++++++++++++- 3 files changed, 499 insertions(+), 1 deletion(-) diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py index 0ced3503..d4bb2f3a 100644 --- a/modules/datamodels/datamodelChatbot.py +++ b/modules/datamodels/datamodelChatbot.py @@ -68,6 +68,74 @@ class DeleteResponse(BaseModel, ModelMixin): thread_id: str = Field(..., description="Deleted thread ID") +# Tool Management Models +class ToolInfo(BaseModel, ModelMixin): + """Information about a chatbot tool""" + + id: str = Field(..., description="Tool UUID") + tool_id: str = Field( + ..., description="Tool identifier (e.g., 'shared.tavily_search')" + ) + name: str = Field(..., description="Tool function name") + label: str = Field(..., description="Display label for the tool") + category: str = Field(..., description="Tool category (shared or customer)") + description: str = Field(..., description="Tool description") + is_active: bool = Field(..., description="Whether the tool is active") + date_created: float = Field(..., description="Creation timestamp") + date_updated: float = Field(..., description="Last update timestamp") + + +class ToolListResponse(BaseModel, ModelMixin): + """Response model for listing all tools""" + + tools: List[ToolInfo] = Field(..., description="List of available tools") + + +class GrantToolRequest(BaseModel, ModelMixin): + """Request model for granting a tool to a user""" + + user_id: str = Field(..., description="User ID to grant the tool to") + tool_id: str = Field(..., description="Tool UUID from tools table") + + +class GrantToolResponse(BaseModel, ModelMixin): + """Response model after granting a tool""" + + message: str = Field(..., description="Confirmation message") + user_id: str = Field(..., description="User ID") + tool_id: str = Field(..., description="Tool UUID") + + +class RevokeToolRequest(BaseModel, ModelMixin): + """Request model for revoking a tool from a user""" + + user_id: str = Field(..., description="User ID to revoke the tool from") + tool_id: str = Field(..., description="Tool UUID from tools table") + + +class RevokeToolResponse(BaseModel, ModelMixin): + """Response model after revoking a tool""" + + message: str = Field(..., description="Confirmation message") + user_id: str = Field(..., description="User ID") + tool_id: str = Field(..., description="Tool UUID") + + +class UpdateToolRequest(BaseModel, ModelMixin): + """Request model for updating a tool's label and description""" + + label: Optional[str] = Field(None, description="New label for the tool") + description: Optional[str] = Field(None, description="New description for the tool") + + +class UpdateToolResponse(BaseModel, ModelMixin): + """Response model after updating a tool""" + + message: str = Field(..., description="Confirmation message") + tool_id: str = Field(..., description="Tool UUID") + updated_fields: List[str] = Field(..., description="List of updated field names") + + # Register model labels for internationalization register_model_labels( "MessageItem", diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index 510018c3..9c2f7aa6 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -683,3 +683,210 @@ async def delete_thread_for_user( ) logger.info(f"Successfully deleted thread {thread_id} for user {user.id}") + + +# Tool Management Functions + + +async def get_all_tools(*, session: AsyncSession) -> List[dict]: + """Get all tools from the database. + + Args: + session: The database session for querying. + + Returns: + List of tool dictionaries with all tool information. + """ + from modules.features.chatBot.database import Tool + + logger.info("Fetching all tools from database") + + stmt = select(Tool).order_by(Tool.category, Tool.name) + result = await session.execute(stmt) + tools = result.scalars().all() + + tool_list = [] + for tool in tools: + tool_dict = { + "id": str(tool.id), + "tool_id": tool.tool_id, + "name": tool.name, + "label": tool.label, + "category": tool.category, + "description": tool.description, + "is_active": tool.is_active, + "date_created": tool.date_created.timestamp(), + "date_updated": tool.date_updated.timestamp(), + } + tool_list.append(tool_dict) + + logger.info(f"Retrieved {len(tool_list)} tools from database") + return tool_list + + +async def grant_tool_to_user( + *, user_id: str, tool_id: str, session: AsyncSession +) -> None: + """Grant a tool to a user. + + Args: + user_id: The user ID to grant the tool to. + tool_id: The tool UUID from the tools table. + session: The database session for querying/updating. + + Raises: + ValueError: If the tool doesn't exist, is not active, or user already has the tool. + """ + from modules.features.chatBot.database import Tool, UserToolMapping + import uuid + + logger.info(f"Granting tool {tool_id} to user {user_id}") + + # Convert tool_id string to UUID + try: + tool_uuid = uuid.UUID(tool_id) + except ValueError: + raise ValueError(f"Invalid tool ID format: {tool_id}") + + # Check if tool exists and is active + stmt = select(Tool).where(Tool.id == tool_uuid) + result = await session.execute(stmt) + tool = result.scalar_one_or_none() + + if tool is None: + raise ValueError(f"Tool with ID {tool_id} does not exist") + + if not tool.is_active: + raise ValueError( + f"Cannot grant inactive tool '{tool.label}' (tool_id: {tool.tool_id}). " + f"Please activate the tool first before granting it to users." + ) + + # Check if user already has this tool + stmt = select(UserToolMapping).where( + UserToolMapping.user_id == user_id, UserToolMapping.tool_id == tool_uuid + ) + result = await session.execute(stmt) + existing_mapping = result.scalar_one_or_none() + + if existing_mapping is not None: + raise ValueError( + f"User {user_id} already has access to tool '{tool.label}' (tool_id: {tool.tool_id})" + ) + + # Create new mapping + new_mapping = UserToolMapping( + user_id=user_id, + tool_id=tool_uuid, + is_active=True, + ) + + session.add(new_mapping) + await session.commit() + + logger.info(f"Successfully granted tool {tool_id} ({tool.label}) to user {user_id}") + + +async def revoke_tool_from_user( + *, user_id: str, tool_id: str, session: AsyncSession +) -> None: + """Revoke a tool from a user by deleting the mapping. + + Args: + user_id: The user ID to revoke the tool from. + tool_id: The tool UUID from the tools table. + session: The database session for deleting. + + Raises: + ValueError: If the mapping doesn't exist. + """ + from modules.features.chatBot.database import UserToolMapping + import uuid + + logger.info(f"Revoking tool {tool_id} from user {user_id}") + + # Convert tool_id string to UUID + try: + tool_uuid = uuid.UUID(tool_id) + except ValueError: + raise ValueError(f"Invalid tool ID format: {tool_id}") + + # Delete the mapping + stmt = delete(UserToolMapping).where( + UserToolMapping.user_id == user_id, UserToolMapping.tool_id == tool_uuid + ) + result = await session.execute(stmt) + await session.commit() + + # Check if any rows were deleted + if result.rowcount == 0: + raise ValueError( + f"User {user_id} does not have access to tool {tool_id}, or the mapping does not exist" + ) + + logger.info(f"Successfully revoked tool {tool_id} from user {user_id}") + + +async def update_tool( + *, + tool_id: str, + label: Optional[str], + description: Optional[str], + session: AsyncSession, +) -> List[str]: + """Update a tool's label and/or description. + + Args: + tool_id: The tool UUID to update. + label: Optional new label for the tool. + description: Optional new description for the tool. + session: The database session for updating. + + Returns: + List of updated field names. + + Raises: + ValueError: If the tool doesn't exist or no fields provided to update. + """ + from modules.features.chatBot.database import Tool + import uuid + + logger.info(f"Updating tool {tool_id}") + + # Validate that at least one field is provided + if label is None and description is None: + raise ValueError("At least one field (label or description) must be provided") + + # Convert tool_id string to UUID + try: + tool_uuid = uuid.UUID(tool_id) + except ValueError: + raise ValueError(f"Invalid tool ID format: {tool_id}") + + # Check if tool exists + stmt = select(Tool).where(Tool.id == tool_uuid) + result = await session.execute(stmt) + tool = result.scalar_one_or_none() + + if tool is None: + raise ValueError(f"Tool with ID {tool_id} does not exist") + + # Build update values + update_values = {"date_updated": datetime.now(timezone.utc)} + updated_fields = [] + + if label is not None: + update_values["label"] = label + updated_fields.append("label") + + if description is not None: + update_values["description"] = description + updated_fields.append("description") + + # Update the tool + stmt = update(Tool).where(Tool.id == tool_uuid).values(**update_values) + await session.execute(stmt) + await session.commit() + + logger.info(f"Successfully updated tool {tool_id}, fields: {updated_fields}") + return updated_fields diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index 1c2b5ed6..b9fbc092 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -12,7 +12,7 @@ from modules.features.chatBot.database import get_async_db_session from modules.features.chatBot.service import ( get_or_create_thread_for_user, ) -from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelUam import User, UserPrivilege from modules.datamodels.datamodelChatbot import ( ChatMessageRequest, MessageItem, @@ -22,6 +22,14 @@ from modules.datamodels.datamodelChatbot import ( ThreadDetail, RenameThreadRequest, DeleteResponse, + ToolListResponse, + ToolInfo, + GrantToolRequest, + GrantToolResponse, + RevokeToolRequest, + RevokeToolResponse, + UpdateToolRequest, + UpdateToolResponse, ) from modules.security.auth import getCurrentUser, limiter from modules.features.chatBot import service as chat_service @@ -311,3 +319,218 @@ async def delete_thread( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to delete thread: {type(e).__name__}: {str(e) or 'No error message provided'}", ) + + +# Tool Management Endpoints + + +@router.get("/tools", response_model=ToolListResponse) +@limiter.limit("30/minute") +async def get_all_tools( + *, + request: Request, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), +) -> ToolListResponse: + """ + Get all available chatbot tools. + Only accessible to system administrators. + """ + try: + # Check SYSADMIN permission + if currentUser.privilege != UserPrivilege.SYSADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only system administrators can view tools", + ) + + # Get all tools from service + tools_data = await chat_service.get_all_tools(session=session) + + # Convert to ToolInfo objects + tools = [ToolInfo(**tool) for tool in tools_data] + + logger.info(f"User {currentUser.id} retrieved {len(tools)} tools") + + return ToolListResponse(tools=tools) + + except HTTPException: + raise + except Exception as e: + logger.error( + f"Error retrieving tools: {type(e).__name__}: {str(e)}", exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve tools: {type(e).__name__}: {str(e) or 'No error message provided'}", + ) + + +@router.post("/tools/grant", response_model=GrantToolResponse) +@limiter.limit("10/minute") +async def grant_tool_to_user( + *, + request: Request, + grant_request: GrantToolRequest, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), +) -> GrantToolResponse: + """ + Grant a tool to a user. + Only accessible to system administrators. + """ + try: + # Check SYSADMIN permission + if currentUser.privilege != UserPrivilege.SYSADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only system administrators can grant tools", + ) + + # Grant the tool + await chat_service.grant_tool_to_user( + user_id=grant_request.user_id, + tool_id=grant_request.tool_id, + session=session, + ) + + logger.info( + f"User {currentUser.id} granted tool {grant_request.tool_id} to user {grant_request.user_id}" + ) + + return GrantToolResponse( + message=f"Tool successfully granted to user {grant_request.user_id}", + user_id=grant_request.user_id, + tool_id=grant_request.tool_id, + ) + + except ValueError as e: + logger.error(f"Validation error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) or "Invalid request", + ) + except HTTPException: + raise + except Exception as e: + logger.error( + f"Error granting tool: {type(e).__name__}: {str(e)}", exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to grant tool: {type(e).__name__}: {str(e) or 'No error message provided'}", + ) + + +@router.delete("/tools/revoke", response_model=RevokeToolResponse) +@limiter.limit("10/minute") +async def revoke_tool_from_user( + *, + request: Request, + revoke_request: RevokeToolRequest, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), +) -> RevokeToolResponse: + """ + Revoke a tool from a user. + Only accessible to system administrators. + """ + try: + # Check SYSADMIN permission + if currentUser.privilege != UserPrivilege.SYSADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only system administrators can revoke tools", + ) + + # Revoke the tool + await chat_service.revoke_tool_from_user( + user_id=revoke_request.user_id, + tool_id=revoke_request.tool_id, + session=session, + ) + + logger.info( + f"User {currentUser.id} revoked tool {revoke_request.tool_id} from user {revoke_request.user_id}" + ) + + return RevokeToolResponse( + message=f"Tool successfully revoked from user {revoke_request.user_id}", + user_id=revoke_request.user_id, + tool_id=revoke_request.tool_id, + ) + + except ValueError as e: + logger.error(f"Validation error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) or "Invalid request", + ) + except HTTPException: + raise + except Exception as e: + logger.error( + f"Error revoking tool: {type(e).__name__}: {str(e)}", exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to revoke tool: {type(e).__name__}: {str(e) or 'No error message provided'}", + ) + + +@router.patch("/tools/{tool_id}", response_model=UpdateToolResponse) +@limiter.limit("10/minute") +async def update_tool( + *, + request: Request, + tool_id: str, + update_request: UpdateToolRequest, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), +) -> UpdateToolResponse: + """ + Update a tool's label and/or description. + Only accessible to system administrators. + """ + try: + # Check SYSADMIN permission + if currentUser.privilege != UserPrivilege.SYSADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only system administrators can update tools", + ) + + # Update the tool + updated_fields = await chat_service.update_tool( + tool_id=tool_id, + label=update_request.label, + description=update_request.description, + session=session, + ) + + logger.info( + f"User {currentUser.id} updated tool {tool_id}, fields: {updated_fields}" + ) + + return UpdateToolResponse( + message="Tool successfully updated", + tool_id=tool_id, + updated_fields=updated_fields, + ) + + except ValueError as e: + logger.error(f"Validation error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) or "Invalid request", + ) + except HTTPException: + raise + except Exception as e: + logger.error( + f"Error updating tool: {type(e).__name__}: {str(e)}", exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update tool: {type(e).__name__}: {str(e) or 'No error message provided'}", + ) From 2b5d7506d0e336c214e273eb760223f1730c3e34 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Thu, 9 Oct 2025 16:23:29 +0200 Subject: [PATCH 27/29] feat: add endpoints to get tools per user --- modules/features/chatBot/service.py | 49 +++++++++++++++++ modules/routes/routeChatbot.py | 85 +++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index 9c2f7aa6..e4b63ff9 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -890,3 +890,52 @@ async def update_tool( logger.info(f"Successfully updated tool {tool_id}, fields: {updated_fields}") return updated_fields + + +async def get_tools_for_user(*, user_id: str, session: AsyncSession) -> List[dict]: + """Get all tools granted to a specific user. + + Args: + user_id: The user ID to get tools for. + session: The database session for querying. + + Returns: + List of tool dictionaries with all tool information. + """ + from modules.features.chatBot.database import Tool, UserToolMapping + + logger.info(f"Fetching tools for user {user_id}") + + # Query tools that are granted to the user + # Join UserToolMapping with Tool table + # Filter by user_id and active status + stmt = ( + select(Tool) + .join(UserToolMapping, Tool.id == UserToolMapping.tool_id) + .where( + UserToolMapping.user_id == user_id, + UserToolMapping.is_active == True, + Tool.is_active == True, + ) + .order_by(Tool.category, Tool.name) + ) + result = await session.execute(stmt) + tools = result.scalars().all() + + tool_list = [] + for tool in tools: + tool_dict = { + "id": str(tool.id), + "tool_id": tool.tool_id, + "name": tool.name, + "label": tool.label, + "category": tool.category, + "description": tool.description, + "is_active": tool.is_active, + "date_created": tool.date_created.timestamp(), + "date_updated": tool.date_updated.timestamp(), + } + tool_list.append(tool_dict) + + logger.info(f"Retrieved {len(tool_list)} tools for user {user_id}") + return tool_list diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index b9fbc092..65151017 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -534,3 +534,88 @@ async def update_tool( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update tool: {type(e).__name__}: {str(e) or 'No error message provided'}", ) + + +@router.get("/tools/user/{user_id}", response_model=ToolListResponse) +@limiter.limit("30/minute") +async def get_tools_for_specific_user( + *, + request: Request, + user_id: str, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), +) -> ToolListResponse: + """ + Get all tools granted to a specific user. + Only accessible to system administrators. + """ + try: + # Check SYSADMIN permission + if currentUser.privilege != UserPrivilege.SYSADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only system administrators can view user tools", + ) + + # Get tools for the specified user + tools_data = await chat_service.get_tools_for_user( + user_id=user_id, session=session + ) + + # Convert to ToolInfo objects + tools = [ToolInfo(**tool) for tool in tools_data] + + logger.info( + f"User {currentUser.id} retrieved {len(tools)} tools for user {user_id}" + ) + + return ToolListResponse(tools=tools) + + except HTTPException: + raise + except Exception as e: + logger.error( + f"Error retrieving tools for user {user_id}: {type(e).__name__}: {str(e)}", + exc_info=True, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve tools for user: {type(e).__name__}: {str(e) or 'No error message provided'}", + ) + + +@router.get("/tools/me", response_model=ToolListResponse) +@limiter.limit("30/minute") +async def get_my_tools( + *, + request: Request, + currentUser: User = Depends(getCurrentUser), + session: AsyncSession = Depends(get_async_db_session), +) -> ToolListResponse: + """ + Get all tools the current user has access to. + """ + try: + # Get tools for the current user + tools_data = await chat_service.get_tools_for_user( + user_id=currentUser.id, session=session + ) + + # Convert to ToolInfo objects + tools = [ToolInfo(**tool) for tool in tools_data] + + logger.info( + f"User {currentUser.id} retrieved {len(tools)} tools for themselves" + ) + + return ToolListResponse(tools=tools) + + except Exception as e: + logger.error( + f"Error retrieving tools for user {currentUser.id}: {type(e).__name__}: {str(e)}", + exc_info=True, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve your tools: {type(e).__name__}: {str(e) or 'No error message provided'}", + ) From ba1daa2d7317d494de91b50e4f95374713701073 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Thu, 9 Oct 2025 16:56:27 +0200 Subject: [PATCH 28/29] feat: allow users to specify tools when posting messages --- modules/datamodels/datamodelChatbot.py | 4 + modules/features/chatBot/service.py | 129 ++++++++++++++++++++----- modules/routes/routeChatbot.py | 36 ++++++- 3 files changed, 144 insertions(+), 25 deletions(-) diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py index d4bb2f3a..762996b3 100644 --- a/modules/datamodels/datamodelChatbot.py +++ b/modules/datamodels/datamodelChatbot.py @@ -20,6 +20,10 @@ class ChatMessageRequest(BaseModel, ModelMixin): None, description="Thread ID (creates new thread if not provided)" ) message: str = Field(..., description="User message content") + tools: Optional[List[str]] = Field( + None, + description="List of tool IDs to use. If not provided, all user's tools will be used", + ) class ChatMessageResponse(BaseModel, ModelMixin): diff --git a/modules/features/chatBot/service.py b/modules/features/chatBot/service.py index e4b63ff9..8239ba53 100644 --- a/modules/features/chatBot/service.py +++ b/modules/features/chatBot/service.py @@ -301,6 +301,7 @@ async def post_message( thread_id: str, message: str, user: User, + tool_ids: List[str], ) -> ChatMessageResponse: """Post a chat message to the chatbot and return the response. @@ -308,21 +309,19 @@ async def post_message( thread_id: The unique identifier for the chat thread. message: The content of the chat message. user: The current user. + tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools. Returns: The response containing the full chat message history and thread ID. """ - logger.info(f"User {user.id} posted message to thread {thread_id}") - - # Get user permissions - tool_ids = permissions.get_chatbot_tools(user_id=user.id) - if not tool_ids: - raise ValueError("User does not have permission to use any chatbot tools") + logger.info( + f"User {user.id} posted message to thread {thread_id} with {len(tool_ids)} tools" + ) model_name = permissions.get_chatbot_model(user_id=user.id) system_prompt = permissions.get_system_prompt(user_id=user.id) - # Get tools from registry + # Get tools from registry (empty list if no tools) registry = get_registry() tools = registry.get_tool_instances(tool_ids=tool_ids) @@ -377,6 +376,7 @@ async def post_message_stream( thread_id: str, message: str, user: User, + tool_ids: List[str], ) -> AsyncIterator[str]: """Post a chat message to the chatbot and stream progress updates (SSE). @@ -384,32 +384,20 @@ async def post_message_stream( thread_id: The unique identifier for the chat thread. message: The content of the chat message. user: The current user. + tool_ids: List of tool IDs to use for this chat. Can be empty to run without tools. Yields: Server-Sent Events formatted strings containing status updates and final response. """ - logger.info(f"User {user.id} streaming message to thread {thread_id}") + logger.info( + f"User {user.id} streaming message to thread {thread_id} with {len(tool_ids)} tools" + ) try: - # Get user permissions - tool_ids = permissions.get_chatbot_tools(user_id=user.id) - if not tool_ids: - yield ( - "data: " - + json.dumps( - { - "type": "error", - "message": "User does not have permission to use any chatbot tools", - } - ) - + "\n\n" - ) - return - model_name = permissions.get_chatbot_model(user_id=user.id) system_prompt = permissions.get_system_prompt(user_id=user.id) - # Get tools from registry + # Get tools from registry (empty list if no tools) registry = get_registry() tools = registry.get_tool_instances(tool_ids=tool_ids) @@ -939,3 +927,96 @@ async def get_tools_for_user(*, user_id: str, session: AsyncSession) -> List[dic logger.info(f"Retrieved {len(tool_list)} tools for user {user_id}") return tool_list + + +async def validate_and_get_tools_for_request( + *, + user_id: str, + requested_tool_ids: Optional[List[str]], + session: AsyncSession, +) -> List[str]: + """Validate and get tool IDs for a chat request. + + This function validates that the user has access to the requested tools. + If no tools are requested (None), it returns all tools the user has access to. + If an empty list is provided, it returns an empty list (no tools). + + Args: + user_id: The user ID making the request. + requested_tool_ids: Optional list of tool UUIDs (id field) requested by the user. + - None: Use all tools the user has access to + - []: Use no tools at all + - ["uuid1", "uuid2"]: Use only the specified tools + session: The database session for querying. + + Returns: + List of validated tool IDs (tool_id field, not UUID) that the user can use. + + Raises: + PermissionError: If the user requests tools they don't have access to. + ValueError: If the user has no tools available when trying to use all tools. + """ + from modules.features.chatBot.database import Tool, UserToolMapping + import uuid + + logger.info(f"Validating tools for user {user_id}") + + # If empty list is explicitly provided, return empty list (no tools) + if requested_tool_ids is not None and len(requested_tool_ids) == 0: + logger.info( + f"Empty tool list requested, chatbot will run without tools for user {user_id}" + ) + return [] + + # Get all tools the user has access to + stmt = ( + select(Tool) + .join(UserToolMapping, Tool.id == UserToolMapping.tool_id) + .where( + UserToolMapping.user_id == user_id, + UserToolMapping.is_active == True, + Tool.is_active == True, + ) + ) + result = await session.execute(stmt) + user_tools = result.scalars().all() + + # Create mappings for both UUID and tool_id + user_tool_ids_by_uuid = {str(tool.id): tool.tool_id for tool in user_tools} + user_tool_ids = set(user_tool_ids_by_uuid.values()) + + if not user_tool_ids: + logger.warning(f"User {user_id} has no tools available") + raise ValueError("User does not have access to any chatbot tools") + + # If no specific tools requested (None), return all user's tools + if requested_tool_ids is None: + logger.info( + f"No specific tools requested, returning all {len(user_tool_ids)} tools for user {user_id}" + ) + return list(user_tool_ids) + + # Convert requested UUIDs to tool_ids and validate access + requested_tool_ids_result = [] + unauthorized_uuids = [] + + for requested_uuid in requested_tool_ids: + if requested_uuid in user_tool_ids_by_uuid: + # User has access to this tool + requested_tool_ids_result.append(user_tool_ids_by_uuid[requested_uuid]) + else: + # User doesn't have access to this tool + unauthorized_uuids.append(requested_uuid) + + if unauthorized_uuids: + logger.warning( + f"User {user_id} requested unauthorized tool UUIDs: {unauthorized_uuids}" + ) + raise PermissionError( + f"You do not have access to the following tools: {', '.join(unauthorized_uuids)}" + ) + + logger.info( + f"Validated {len(requested_tool_ids_result)} requested tools for user {user_id}" + ) + return requested_tool_ids_result diff --git a/modules/routes/routeChatbot.py b/modules/routes/routeChatbot.py index 65151017..a4757c84 100644 --- a/modules/routes/routeChatbot.py +++ b/modules/routes/routeChatbot.py @@ -62,6 +62,13 @@ async def post_chat_message_stream( Returns Server-Sent Events (SSE) stream with status updates and final response. """ try: + # Validate and get tools for the request + tool_ids = await chat_service.validate_and_get_tools_for_request( + user_id=currentUser.id, + requested_tool_ids=message_request.tools, + session=session, + ) + # Get or create thread using helper function thread_id = await get_or_create_thread_for_user( thread_id=message_request.thread_id, @@ -80,6 +87,7 @@ async def post_chat_message_stream( thread_id=thread_id, message=message_request.message, user=currentUser, + tool_ids=tool_ids, ), media_type="text/event-stream", headers={ @@ -88,6 +96,18 @@ async def post_chat_message_stream( }, ) + except PermissionError as e: + logger.error(f"Permission error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) or "Permission denied", + ) + except ValueError as e: + logger.error(f"Validation error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) or "Permission denied", + ) except Exception as e: logger.error( f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True @@ -114,6 +134,13 @@ async def post_chat_message( For streaming updates, use the /message/stream endpoint instead. """ try: + # Validate and get tools for the request + tool_ids = await chat_service.validate_and_get_tools_for_request( + user_id=currentUser.id, + requested_tool_ids=message_request.tools, + session=session, + ) + # Get or create thread using helper function thread_id = await get_or_create_thread_for_user( thread_id=message_request.thread_id, @@ -129,16 +156,23 @@ async def post_chat_message( thread_id=thread_id, message=message_request.message, user=currentUser, + tool_ids=tool_ids, ) return response - except ValueError as e: + except PermissionError as e: logger.error(f"Permission error: {str(e)}", exc_info=True) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=str(e) or "Permission denied", ) + except ValueError as e: + logger.error(f"Validation error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) or "Permission denied", + ) except Exception as e: logger.error( f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True From 31f8192bd3304895508f0935f51a8391c1a392e9 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Mon, 13 Oct 2025 17:03:10 +0200 Subject: [PATCH 29/29] feat: add valueon powerbi tool --- .../customerTools/toolValueOnPowerBi.py | 362 ++++++++++++++++++ 1 file changed, 362 insertions(+) create mode 100644 modules/features/chatBot/chatbotTools/customerTools/toolValueOnPowerBi.py diff --git a/modules/features/chatBot/chatbotTools/customerTools/toolValueOnPowerBi.py b/modules/features/chatBot/chatbotTools/customerTools/toolValueOnPowerBi.py new file mode 100644 index 00000000..7c00ee6a --- /dev/null +++ b/modules/features/chatBot/chatbotTools/customerTools/toolValueOnPowerBi.py @@ -0,0 +1,362 @@ +"""Power BI Query Tool for LangGraph. + +This tool provides DAX query capabilities for Power BI datasets +via the Power BI REST API. Only read-only queries are allowed. +""" + +import logging +import asyncio +import os +import re +import functools +from typing import Annotated + +import anyio +import httpx +from langchain_core.tools import tool +from msal import ConfidentialClientApplication, SerializableTokenCache + +from modules.shared.configuration import APP_CONFIG + +logger = logging.getLogger(__name__) + + +# Configuration constants - encapsulated in this file +POWERBI_DATASET_ID = APP_CONFIG.get("VALUEON_POWERBI_DATASET_ID", "") +POWERBI_CLIENT_ID = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_ID", "") +POWERBI_CLIENT_SECRET = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_SECRET", "") +POWERBI_TENANT_ID = APP_CONFIG.get("VALUEON_POWERBI_TENANT_ID", "") +POWERBI_BASE_URL = "https://api.powerbi.com/v1.0/myorg" +POWERBI_AUTHORITY_BASE = "https://login.microsoftonline.com" +POWERBI_SCOPE = ["https://analysis.windows.net/powerbi/api/.default"] + +# Limit results to prevent excessive context usage +MAX_ROWS_LIMIT = 100 + + +def _validate_environment() -> tuple[bool, str]: + """Validate that all required environment variables are set. + + Returns: + A tuple of (is_valid, error_message) + """ + missing = [] + if not POWERBI_DATASET_ID: + missing.append("POWERBI_DATASET_ID") + if not POWERBI_CLIENT_ID: + missing.append("POWERBI_CLIENT_ID") + if not POWERBI_CLIENT_SECRET: + missing.append("POWERBI_CLIENT_SECRET") + if not POWERBI_TENANT_ID: + missing.append("POWERBI_TENANT_ID") + + if missing: + return False, f"Missing required environment variables: {', '.join(missing)}" + + return True, "" + + +def _validate_dax_query(*, dax_query: str) -> tuple[bool, str]: + """Validate that the query is a valid DAX query. + + Args: + dax_query: The DAX query to validate + + Returns: + A tuple of (is_valid, error_message) + """ + # Remove leading/trailing whitespace + normalized_query = dax_query.strip() + + if not normalized_query: + return False, "Query cannot be empty" + + # DAX queries typically start with EVALUATE, DEFINE, or are table expressions + # We'll be lenient and just check it's not trying to do something dangerous + # DAX is read-only by nature, but we validate structure + + # Check for minimum length + if len(normalized_query) < 5: + return False, "Query is too short to be valid" + + return True, "" + + +def _get_access_token_sync( + *, + tenant_id: str, + client_id: str, + client_secret: str, + authority_base: str = POWERBI_AUTHORITY_BASE, + cache: SerializableTokenCache | None = None, +) -> str: + """Get Power BI access token using MSAL (synchronous). + + Args: + tenant_id: Azure AD tenant ID + client_id: Application client ID + client_secret: Application client secret + authority_base: Azure AD authority base URL + cache: Optional token cache for reuse + + Returns: + Access token string + + Raises: + RuntimeError: If token acquisition fails + """ + authority = f"{authority_base}/{tenant_id}" + + app = ConfidentialClientApplication( + client_id=client_id, + authority=authority, + client_credential=client_secret, + token_cache=cache, + ) + + # Try cache first; fall back to client credentials + result = app.acquire_token_silent( + POWERBI_SCOPE, account=None + ) or app.acquire_token_for_client(scopes=POWERBI_SCOPE) + + if "access_token" not in result: + raise RuntimeError( + f"MSAL token error: {result.get('error')} - {result.get('error_description')}" + ) + + return result["access_token"] + + +async def _get_access_token_async( + *, + tenant_id: str, + client_id: str, + client_secret: str, + **kwargs, +) -> str: + """Get Power BI access token using MSAL (asynchronous). + + Args: + tenant_id: Azure AD tenant ID + client_id: Application client ID + client_secret: Application client secret + **kwargs: Additional arguments for _get_access_token_sync + + Returns: + Access token string + """ + # Create a partial function with arguments pre-filled + func = functools.partial( + _get_access_token_sync, + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + **kwargs, + ) + # Offload the blocking MSAL HTTP call to a worker thread + return await anyio.to_thread.run_sync(func) + + +async def _execute_dax_query( + *, dax_query: str, dataset_id: str, access_token: str +) -> dict: + """Execute a DAX query against Power BI dataset. + + Args: + dax_query: The DAX query to execute + dataset_id: Power BI dataset ID + access_token: Access token for authentication + + Returns: + Dictionary containing query results + + Raises: + RuntimeError: If query execution fails + """ + url = f"{POWERBI_BASE_URL}/datasets/{dataset_id}/executeQueries" + + body = { + "queries": [{"query": dax_query}], + "serializerSettings": {"includeNulls": True}, + } + + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(url, headers=headers, json=body) + + if resp.status_code != 200: + raise RuntimeError( + f"Power BI executeQueries failed: {resp.status_code} - {resp.text}" + ) + + payload = resp.json() + + try: + rows = payload["results"][0]["tables"][0]["rows"] + except (KeyError, IndexError) as e: + raise RuntimeError("Unexpected executeQueries response structure") from e + + # Extract column names from the first row if available + if rows: + columns = list(rows[0].keys()) + else: + columns = [] + + return {"columns": columns, "rows": rows} + + +def _strip_table_qualifier(*, column_name: str) -> str: + """Strip table qualifier from column name. + + Power BI often returns columns as 'Table[Column]'. This strips to 'Column'. + + Args: + column_name: The column name to process + + Returns: + Processed column name + """ + if "[" in column_name and column_name.endswith("]"): + return column_name.split("[", 1)[1][:-1] + return column_name + + +def _format_results(*, columns: list[str], rows: list[dict], max_rows: int) -> str: + """Format query results into a readable string. + + Args: + columns: List of column names + rows: List of row data (as dictionaries) + max_rows: Maximum number of rows to display + + Returns: + Formatted string representation of the results + """ + total_rows = len(rows) + + if total_rows == 0: + return "Query executed successfully but returned no results." + + # Strip table qualifiers from column names + clean_columns = [_strip_table_qualifier(column_name=col) for col in columns] + + # Limit rows to max_rows + display_rows = rows[:max_rows] + truncated = total_rows > max_rows + + # Calculate column widths + col_widths = [len(str(col)) for col in clean_columns] + for row in display_rows: + for i, col in enumerate(columns): + value = row.get(col, "") + col_widths[i] = max(col_widths[i], len(str(value))) + + # Build header + header_parts = [] + for col, width in zip(clean_columns, col_widths): + header_parts.append(str(col).ljust(width)) + header = " | ".join(header_parts) + separator = "-" * len(header) + + # Build rows + row_lines = [] + for row in display_rows: + row_parts = [] + for col, width in zip(columns, col_widths): + value = row.get(col, "") + row_parts.append(str(value).ljust(width)) + row_lines.append(" | ".join(row_parts)) + + # Combine all parts + result_parts = [ + f"Query returned {total_rows} row(s):", + ] + + if truncated: + result_parts.append( + f"(Results limited to {max_rows} rows for context efficiency)\n" + ) + else: + result_parts.append("") + + result_parts.extend([header, separator, "\n".join(row_lines)]) + + return "\n".join(result_parts) + + +@tool +async def query_powerbi_data( + dax_query: Annotated[str, "The DAX query to execute against the Power BI dataset"], +) -> str: + """Execute a DAX query against the Power BI dataset to access warehouse inventory data. + + This tool provides access to a Power BI table called 'data_full' which contains + articles available in the warehouse of the user. Use DAX (Data Analysis Expressions) + queries to retrieve and analyze this inventory data. + + Available table: + - 'data_full': Contains warehouse inventory articles and their details + + Common query patterns: + - View all data: EVALUATE 'data_full' + - With filter: EVALUATE FILTER('data_full', [Column] = "Value") + - Top N rows: EVALUATE TOPN(10, 'data_full', [Column], DESC) + - Calculated: EVALUATE SUMMARIZE('data_full', [Column1], "Total", SUM([Column2])) + + Results are limited to 100 rows maximum for efficiency. + + Args: + dax_query: The DAX query to execute (e.g., "EVALUATE 'data_full'") + + Returns: + A formatted string containing the query results with columns and rows + """ + try: + # Validate environment configuration + is_valid_env, error_msg = _validate_environment() + if not is_valid_env: + logger.error(f"Environment validation failed: {error_msg}") + return f"Configuration Error: {error_msg}" + + # Validate the query + is_valid_query, error_msg = _validate_dax_query(dax_query=dax_query) + if not is_valid_query: + logger.warning(f"Invalid query attempt: {dax_query[:100]}...") + return f"Query Validation Error: {error_msg}" + + logger.info(f"Executing Power BI query: {dax_query[:100]}...") + + # Get access token + access_token = await _get_access_token_async( + tenant_id=POWERBI_TENANT_ID, + client_id=POWERBI_CLIENT_ID, + client_secret=POWERBI_CLIENT_SECRET, + ) + + # Execute the query + result = await _execute_dax_query( + dax_query=dax_query, + dataset_id=POWERBI_DATASET_ID, + access_token=access_token, + ) + + # Format and return results + formatted_output = _format_results( + columns=result["columns"], rows=result["rows"], max_rows=MAX_ROWS_LIMIT + ) + + logger.info( + f"Query completed successfully, returned {len(result['rows'])} row(s)" + ) + return formatted_output + + except RuntimeError as e: + logger.error(f"Runtime error in query_powerbi_data tool: {str(e)}") + return f"Error executing query: {str(e)}" + except Exception as e: + logger.error(f"Unexpected error in query_powerbi_data tool: {str(e)}") + return f"Unexpected error: {str(e)}"