This commit is contained in:
ValueOn AG 2025-10-15 12:41:08 +02:00
commit 82b2fd36dc
26 changed files with 5672 additions and 838 deletions

270
app.py
View file

@ -1,10 +1,22 @@
import os
import sys
import asyncio
from urllib.parse import quote_plus
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from modules.features.chatBot.database import init_models as init_chatbot_models
os.environ["NUMEXPR_MAX_THREADS"] = "12"
# Fix for Windows asyncio compatibility with psycopg
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
from fastapi import FastAPI, HTTPException, Depends, Body, status, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.security import HTTPBearer
from contextlib import asynccontextmanager
import logging
from logging.handlers import RotatingFileHandler
@ -20,32 +32,36 @@ class DailyRotatingFileHandler(RotatingFileHandler):
A rotating file handler that automatically switches to a new file when the date changes.
The log file name includes the current date and switches at midnight.
"""
def __init__(self, log_dir, filename_prefix, max_bytes=10485760, backup_count=5, **kwargs):
def __init__(
self, log_dir, filename_prefix, max_bytes=10485760, backup_count=5, **kwargs
):
self.log_dir = log_dir
self.filename_prefix = filename_prefix
self.current_date = None
self.current_file = None
# Initialize with today's file
self._update_file_if_needed()
# Call parent constructor with current file
super().__init__(self.current_file, maxBytes=max_bytes, backupCount=backup_count, **kwargs)
super().__init__(
self.current_file, maxBytes=max_bytes, backupCount=backup_count, **kwargs
)
    def _update_file_if_needed(self) -> bool:
        """Update the log file path if the date has changed.

        Compares today's date (local time, YYYYMMDD) against the cached
        self.current_date and, on a new day, points self.current_file at
        log_dir/<prefix>_<YYYYMMDD>.log.

        Returns:
            True when the target file path actually changed (i.e. a new
            day's file should be opened), False otherwise.
        """
        today = datetime.now().strftime("%Y%m%d")
        if self.current_date != today:
            self.current_date = today
            new_file = os.path.join(self.log_dir, f"{self.filename_prefix}_{today}.log")
            # Guard against redundant switches: only report a change when the
            # computed path differs from the one currently in use.
            if self.current_file != new_file:
                self.current_file = new_file
                return True
        return False
def emit(self, record):
"""Emit a log record, switching files if date has changed"""
# Check if we need to switch to a new file
@ -54,16 +70,17 @@ class DailyRotatingFileHandler(RotatingFileHandler):
if self.stream:
self.stream.close()
self.stream = None
# Update the baseFilename for the parent class
self.baseFilename = self.current_file
# Reopen the stream
if not self.delay:
self.stream = self._open()
# Call parent emit method
super().emit(record)
def initLogging():
"""Initialize logging with configuration from APP_CONFIG"""
# Get log level from config (default to INFO if not found)
@ -76,33 +93,39 @@ def initLogging():
# If relative path, make it relative to the gateway directory
gatewayDir = os.path.dirname(os.path.abspath(__file__))
logDir = os.path.join(gatewayDir, logDir)
# Ensure log directory exists
os.makedirs(logDir, exist_ok=True)
# Create formatters - using single line format
consoleFormatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
)
# File formatter with more detailed error information but still single line
fileFormatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s - %(pathname)s:%(lineno)d - %(funcName)s",
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
)
# Add filter to exclude Chrome DevTools requests
class ChromeDevToolsFilter(logging.Filter):
def filter(self, record):
return not (isinstance(record.msg, str) and
('.well-known/appspecific/com.chrome.devtools.json' in record.msg or
'Request: /index.html' in record.msg))
return not (
isinstance(record.msg, str)
and (
".well-known/appspecific/com.chrome.devtools.json" in record.msg
or "Request: /index.html" in record.msg
)
)
# Add filter to exclude all httpcore loggers (including sub-loggers)
class HttpcoreStarFilter(logging.Filter):
def filter(self, record):
return not (record.name == 'httpcore' or record.name.startswith('httpcore.'))
return not (
record.name == "httpcore" or record.name.startswith("httpcore.")
)
# Add filter to exclude HTTP debug messages
class HTTPDebugFilter(logging.Filter):
@ -110,14 +133,14 @@ def initLogging():
if isinstance(record.msg, str):
# Filter out HTTP debug messages
http_debug_patterns = [
'receive_response_body.started',
'receive_response_body.complete',
'response_closed.started',
'_send_single_request',
'httpcore.http11',
'httpx._client',
'HTTP Request',
'multipart.multipart'
"receive_response_body.started",
"receive_response_body.complete",
"response_closed.started",
"_send_single_request",
"httpcore.http11",
"httpx._client",
"HTTP Request",
"multipart.multipart",
]
return not any(pattern in record.msg for pattern in http_debug_patterns)
return True
@ -129,13 +152,21 @@ def initLogging():
# Remove only emojis, preserve other Unicode characters like quotes
import re
import unicodedata
# Remove emoji characters specifically
record.msg = ''.join(char for char in record.msg if unicodedata.category(char) != 'So' or not (0x1F600 <= ord(char) <= 0x1F64F or 0x1F300 <= ord(char) <= 0x1F5FF or 0x1F680 <= ord(char) <= 0x1F6FF or 0x1F1E0 <= ord(char) <= 0x1F1FF or 0x2600 <= ord(char) <= 0x26FF or 0x2700 <= ord(char) <= 0x27BF))
# Additionally strip characters not representable in Windows cp1252 (e.g., arrows)
try:
record.msg.encode('cp1252', errors='strict')
except UnicodeEncodeError:
record.msg = record.msg.encode('cp1252', errors='ignore').decode('cp1252', errors='ignore')
record.msg = "".join(
char
for char in record.msg
if unicodedata.category(char) != "So"
or not (
0x1F600 <= ord(char) <= 0x1F64F
or 0x1F300 <= ord(char) <= 0x1F5FF
or 0x1F680 <= ord(char) <= 0x1F6FF
or 0x1F1E0 <= ord(char) <= 0x1F1FF
or 0x2600 <= ord(char) <= 0x26FF
or 0x2700 <= ord(char) <= 0x27BF
)
)
return True
# Configure handlers based on config
@ -154,14 +185,16 @@ def initLogging():
# Add file handler if enabled
if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True):
# Create daily application log file with automatic date switching
rotationSize = int(APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760)) # Default: 10MB
rotationSize = int(
APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760)
) # Default: 10MB
backupCount = int(APP_CONFIG.get("APP_LOGGING_BACKUP_COUNT", 5))
fileHandler = DailyRotatingFileHandler(
log_dir=logDir,
filename_prefix="log_app",
max_bytes=rotationSize,
backup_count=backupCount
backup_count=backupCount,
)
fileHandler.setFormatter(fileFormatter)
fileHandler.addFilter(ChromeDevToolsFilter())
@ -176,11 +209,18 @@ def initLogging():
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
handlers=handlers,
force=True # Force reconfiguration of the root logger
force=True, # Force reconfiguration of the root logger
)
# Silence noisy third-party libraries - use the same level as the root logger
noisyLoggers = ["httpx", "httpcore", "urllib3", "asyncio", "fastapi.security.oauth2", "msal"]
noisyLoggers = [
"httpx",
"httpcore",
"urllib3",
"asyncio",
"fastapi.security.oauth2",
"msal",
]
for loggerName in noisyLoggers:
logging.getLogger(loggerName).setLevel(logging.WARNING)
@ -188,38 +228,142 @@ def initLogging():
logger = logging.getLogger(__name__)
logger.info(f"Logging initialized with level {logLevelName}")
logger.info(f"Log directory: {logDir}")
if APP_CONFIG.get('APP_LOGGING_FILE_ENABLED', True):
if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True):
today = datetime.now().strftime("%Y%m%d")
appLogFile = os.path.join(logDir, f"log_app_{today}.log")
logger.info(f"Application log file: {appLogFile} (auto-switches daily)")
else:
logger.info("Application log file: disabled")
logger.info(f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}")
logger.info(
f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}"
)
def make_sqlalchemy_db_url() -> str:
    """Build the async SQLAlchemy connection URL for PostgreSQL (psycopg driver).

    Reads host, port, database, user and password from APP_CONFIG, falling
    back to local-development defaults. Credentials are percent-encoded so
    that special characters (e.g. '@', ':', '/') cannot corrupt the URL.

    Returns:
        A ``postgresql+psycopg://user:pwd@host:port/db`` connection string.
    """
    host = APP_CONFIG.get("SQLALCHEMY_DB_HOST", "localhost")
    port = APP_CONFIG.get("SQLALCHEMY_DB_PORT", "5432")
    db = APP_CONFIG.get("SQLALCHEMY_DB_DATABASE", "project_gateway")
    # Quote the user as well as the password: the original only quoted the
    # password, so a username containing URL-reserved characters would
    # produce an unparsable connection string.
    user = quote_plus(APP_CONFIG.get("SQLALCHEMY_DB_USER", "postgres"))
    pwd = quote_plus(APP_CONFIG.get("SQLALCHEMY_DB_PASSWORD_SECRET", ""))
    return f"postgresql+psycopg://{user}:{pwd}@{host}:{port}/{db}"
# Initialize logging
initLogging()
logger = logging.getLogger(__name__)
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
# Define lifespan context manager for application startup/shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan context manager.

    Startup: creates the async SQLAlchemy engine/session factory, initializes
    chatbot models, syncs the tool registry into the database, initializes the
    LangGraph checkpointer (best-effort) and starts the event manager.
    Shutdown (after ``yield``): stops the event manager, closes the
    checkpointer and disposes the engine, in reverse order of creation.
    """
    logger.info("Application is starting up")
    # --- Init SQLAlchemy ---
    # pool_pre_ping validates pooled connections before handing them out,
    # guarding against stale connections after a DB restart.
    engine = create_async_engine(
        make_sqlalchemy_db_url(), pool_pre_ping=True, echo=False
    )
    SessionLocal = async_sessionmaker(
        engine, expire_on_commit=False, class_=AsyncSession
    )
    # Expose engine and sessionmaker on app.state so request handlers and
    # other modules can reach them without re-creating connections.
    app.state.checkpoint_engine = engine
    app.state.checkpoint_sessionmaker = SessionLocal
    # NOTE: Might need Alembic migrations in the future
    await init_chatbot_models(engine)
    # --- Sync tools from registry to database ---
    from modules.features.chatBot.database import sync_tools_from_registry

    async with SessionLocal() as session:
        await sync_tools_from_registry(session)
        await session.commit()
    logger.info("Tools synced from registry to database")
    # --- Initialize LangGraph checkpointer ---
    from modules.features.chatBot.utils.checkpointer import (
        initialize_checkpointer,
        close_checkpointer,
    )

    try:
        await initialize_checkpointer()
        logger.info("LangGraph checkpointer initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize LangGraph checkpointer: {str(e)}")
        # Continue startup even if checkpointer fails to initialize
    # --- Init Event Manager ---
    eventManager.start()
    yield
    # --- Cleanup Event Manager ---
    eventManager.stop()
    # --- Cleanup LangGraph checkpointer ---
    await close_checkpointer()
    # --- Cleanup SQLAlchemy ---
    await engine.dispose()
    logger.info("Application has been shut down")
# START APP
app = FastAPI(
title="PowerOn | Data Platform API",
title="PowerOn | Data Platform API",
description=f"Backend API for the Multi-Agent Platform by ValueOn AG ({instanceLabel})",
lifespan=lifespan
lifespan=lifespan,
swagger_ui_init_oauth={
"usePkceWithAuthorizationCodeGrant": True,
},
)
# Configure OpenAPI security scheme for Swagger UI
# This adds the "Authorize" button to the /docs page
security_scheme = HTTPBearer()
app.openapi_schema = None # Reset schema to regenerate with security
def custom_openapi():
    """Generate (and cache) the OpenAPI schema with a Bearer JWT scheme.

    Adds a global ``BearerAuth`` (HTTP bearer, JWT) security requirement so
    the Swagger UI at /docs shows an "Authorize" button. The generated schema
    is cached on ``app.openapi_schema`` so it is only built once.

    Returns:
        The (possibly cached) OpenAPI schema dict.
    """
    if app.openapi_schema:
        return app.openapi_schema
    from fastapi.openapi.utils import get_openapi

    openapi_schema = get_openapi(
        title=app.title,
        version="1.0.0",
        description=app.description,
        routes=app.routes,
    )
    # Add security scheme definition. Use setdefault because get_openapi only
    # emits a "components" section when there is something to put in it;
    # indexing it unconditionally could raise KeyError on a schema-less app.
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "BearerAuth": {
            "type": "http",
            "scheme": "bearer",
            "bearerFormat": "JWT",
            "description": "Enter your JWT token (obtained from login endpoint or browser cookies)",
        }
    }
    # Apply security globally to all endpoints
    # Individual endpoints can override this if needed
    openapi_schema["security"] = [{"BearerAuth": []}]
    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi
# Parse CORS origins from environment variable
def get_allowed_origins():
origins_str = APP_CONFIG.get("APP_ALLOWED_ORIGINS", "http://localhost:8080")
@ -228,73 +372,99 @@ def get_allowed_origins():
logger.info(f"CORS allowed origins: {origins}")
return origins
# CORS configuration using environment variables
app.add_middleware(
CORSMiddleware,
allow_origins= get_allowed_origins(),
allow_origins=get_allowed_origins(),
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["*"],
expose_headers=["*"],
max_age=86400 # Increased caching for preflight requests
max_age=86400, # Increased caching for preflight requests
)
# CSRF protection middleware
from modules.security.csrf import CSRFMiddleware
from modules.security.tokenRefreshMiddleware import TokenRefreshMiddleware, ProactiveTokenRefreshMiddleware
from modules.security.tokenRefreshMiddleware import (
TokenRefreshMiddleware,
ProactiveTokenRefreshMiddleware,
)
app.add_middleware(CSRFMiddleware)
# Token refresh middleware (silent refresh for expired OAuth tokens)
app.add_middleware(TokenRefreshMiddleware, enabled=True)
# Proactive token refresh middleware (refresh tokens before they expire)
app.add_middleware(ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5)
app.add_middleware(
ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5
)
# Run triggered features
import modules.features.init
# Include all routers
from modules.routes.routeAdmin import router as generalRouter
app.include_router(generalRouter)
from modules.routes.routeAttributes import router as attributesRouter
app.include_router(attributesRouter)
from modules.routes.routeDataMandates import router as mandateRouter
app.include_router(mandateRouter)
from modules.routes.routeDataUsers import router as userRouter
app.include_router(userRouter)
from modules.routes.routeDataFiles import router as fileRouter
app.include_router(fileRouter)
from modules.routes.routeDataNeutralization import router as neutralizationRouter
app.include_router(neutralizationRouter)
from modules.routes.routeDataPrompts import router as promptRouter
app.include_router(promptRouter)
from modules.routes.routeDataConnections import router as connectionsRouter
app.include_router(connectionsRouter)
from modules.routes.routeWorkflows import router as workflowRouter
app.include_router(workflowRouter)
from modules.routes.routeChatPlayground import router as chatPlaygroundRouter
app.include_router(chatPlaygroundRouter)
from modules.routes.routeSecurityLocal import router as localRouter
app.include_router(localRouter)
from modules.routes.routeSecurityMsft import router as msftRouter
app.include_router(msftRouter)
from modules.routes.routeSecurityGoogle import router as googleRouter
app.include_router(googleRouter)
from modules.routes.routeVoiceGoogle import router as voiceGoogleRouter
app.include_router(voiceGoogleRouter)
from modules.routes.routeSecurityAdmin import router as adminSecurityRouter
app.include_router(adminSecurityRouter)
app.include_router(adminSecurityRouter)
from modules.routes.routeChatbot import router as chatbotRouter
app.include_router(chatbotRouter)

File diff suppressed because it is too large. Load diff.

View file

@ -1,21 +1,33 @@
"""Chat models: ChatWorkflow, ChatMessage, ChatLog, ChatStat, ChatDocument."""
from typing import List, Dict, Any, Optional
from enum import Enum
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import register_model_labels, ModelMixin
from modules.shared.timezoneUtils import get_utc_timestamp
import uuid
class ChatStat(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)")
messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)")
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
)
workflowId: Optional[str] = Field(
None, description="Foreign key to workflow (for workflow stats)"
)
messageId: Optional[str] = Field(
None, description="Foreign key to message (for message stats)"
)
processingTime: Optional[float] = Field(
None, description="Processing time in seconds"
)
tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
successRate: Optional[float] = Field(None, description="Success rate of operations")
errorCount: Optional[int] = Field(None, description="Number of errors encountered")
register_model_labels(
"ChatStat",
{"en": "Chat Statistics", "fr": "Statistiques de chat"},
@ -32,15 +44,27 @@ register_model_labels(
},
)
class ChatLog(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
)
workflowId: str = Field(description="Foreign key to workflow")
message: str = Field(description="Log message")
type: str = Field(description="Log type (info, warning, error, etc.)")
timestamp: float = Field(default_factory=get_utc_timestamp, description="When the log entry was created (UTC timestamp in seconds)")
timestamp: float = Field(
default_factory=get_utc_timestamp,
description="When the log entry was created (UTC timestamp in seconds)",
)
status: Optional[str] = Field(None, description="Status of the log entry")
progress: Optional[float] = Field(None, description="Progress indicator (0.0 to 1.0)")
performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics")
progress: Optional[float] = Field(
None, description="Progress indicator (0.0 to 1.0)"
)
performance: Optional[Dict[str, Any]] = Field(
None, description="Performance metrics"
)
register_model_labels(
"ChatLog",
{"en": "Chat Log", "fr": "Journal de chat"},
@ -56,8 +80,11 @@ register_model_labels(
},
)
class ChatDocument(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
)
messageId: str = Field(description="Foreign key to message")
fileId: str = Field(description="Foreign key to file")
fileName: str = Field(description="Name of the file")
@ -66,7 +93,11 @@ class ChatDocument(BaseModel, ModelMixin):
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
taskNumber: Optional[int] = Field(None, description="Task number within round")
actionNumber: Optional[int] = Field(None, description="Action number within task")
actionId: Optional[str] = Field(None, description="ID of the action that created this document")
actionId: Optional[str] = Field(
None, description="ID of the action that created this document"
)
register_model_labels(
"ChatDocument",
{"en": "Chat Document", "fr": "Document de chat"},
@ -84,17 +115,26 @@ register_model_labels(
},
)
class ContentMetadata(BaseModel, ModelMixin):
size: int = Field(description="Content size in bytes")
pages: Optional[int] = Field(None, description="Number of pages for multi-page content")
pages: Optional[int] = Field(
None, description="Number of pages for multi-page content"
)
error: Optional[str] = Field(None, description="Processing error if any")
width: Optional[int] = Field(None, description="Width in pixels for images/videos")
height: Optional[int] = Field(None, description="Height in pixels for images/videos")
height: Optional[int] = Field(
None, description="Height in pixels for images/videos"
)
colorMode: Optional[str] = Field(None, description="Color mode")
fps: Optional[float] = Field(None, description="Frames per second for videos")
durationSec: Optional[float] = Field(None, description="Duration in seconds for media")
durationSec: Optional[float] = Field(
None, description="Duration in seconds for media"
)
mimeType: str = Field(description="MIME type of the content")
base64Encoded: bool = Field(description="Whether the data is base64 encoded")
register_model_labels(
"ContentMetadata",
{"en": "Content Metadata", "fr": "Métadonnées du contenu"},
@ -112,10 +152,13 @@ register_model_labels(
},
)
class ContentItem(BaseModel, ModelMixin):
label: str = Field(description="Content label")
data: str = Field(description="Extracted text content")
metadata: ContentMetadata = Field(description="Content metadata")
register_model_labels(
"ContentItem",
{"en": "Content Item", "fr": "Élément de contenu"},
@ -126,9 +169,14 @@ register_model_labels(
},
)
class ChatContentExtracted(BaseModel, ModelMixin):
id: str = Field(description="Reference to source ChatDocument")
contents: List[ContentItem] = Field(default_factory=list, description="List of content items")
contents: List[ContentItem] = Field(
default_factory=list, description="List of content items"
)
register_model_labels(
"ChatContentExtracted",
{"en": "Extracted Content", "fr": "Contenu extrait"},
@ -138,28 +186,58 @@ register_model_labels(
},
)
class ChatMessage(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
)
workflowId: str = Field(description="Foreign key to workflow")
parentMessageId: Optional[str] = Field(None, description="Parent message ID for threading")
documents: List[ChatDocument] = Field(default_factory=list, description="Associated documents")
documentsLabel: Optional[str] = Field(None, description="Label for the set of documents")
parentMessageId: Optional[str] = Field(
None, description="Parent message ID for threading"
)
documents: List[ChatDocument] = Field(
default_factory=list, description="Associated documents"
)
documentsLabel: Optional[str] = Field(
None, description="Label for the set of documents"
)
message: Optional[str] = Field(None, description="Message content")
summary: Optional[str] = Field(None, description="Short summary of this message for planning/history")
summary: Optional[str] = Field(
None, description="Short summary of this message for planning/history"
)
role: str = Field(description="Role of the message sender")
status: str = Field(description="Status of the message (first, step, last)")
sequenceNr: int = Field(description="Sequence number of the message (set automatically)")
publishedAt: float = Field(default_factory=get_utc_timestamp, description="When the message was published (UTC timestamp in seconds)")
sequenceNr: int = Field(
description="Sequence number of the message (set automatically)"
)
publishedAt: float = Field(
default_factory=get_utc_timestamp,
description="When the message was published (UTC timestamp in seconds)",
)
stats: Optional[ChatStat] = Field(None, description="Statistics for this message")
success: Optional[bool] = Field(None, description="Whether the message processing was successful")
actionId: Optional[str] = Field(None, description="ID of the action that produced this message")
actionMethod: Optional[str] = Field(None, description="Method of the action that produced this message")
actionName: Optional[str] = Field(None, description="Name of the action that produced this message")
success: Optional[bool] = Field(
None, description="Whether the message processing was successful"
)
actionId: Optional[str] = Field(
None, description="ID of the action that produced this message"
)
actionMethod: Optional[str] = Field(
None, description="Method of the action that produced this message"
)
actionName: Optional[str] = Field(
None, description="Name of the action that produced this message"
)
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
taskNumber: Optional[int] = Field(None, description="Task number within round")
actionNumber: Optional[int] = Field(None, description="Action number within task")
taskProgress: Optional[str] = Field(None, description="Task progress status: pending, running, success, fail, retry")
actionProgress: Optional[str] = Field(None, description="Action progress status: pending, running, success, fail")
taskProgress: Optional[str] = Field(
None, description="Task progress status: pending, running, success, fail, retry"
)
actionProgress: Optional[str] = Field(
None, description="Action progress status: pending, running, success, fail"
)
register_model_labels(
"ChatMessage",
{"en": "Chat Message", "fr": "Message de chat"},
@ -188,32 +266,139 @@ register_model_labels(
},
)
class ChatWorkflow(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False)
mandateId: str = Field(description="ID of the mandate this workflow belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
status: str = Field(description="Current status of the workflow", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
{"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
{"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
])
name: Optional[str] = Field(None, description="Name of the workflow", frontend_type="text", frontend_readonly=False, frontend_required=True)
currentRound: int = Field(description="Current round number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
currentTask: int = Field(default=0, description="Current task number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
currentAction: int = Field(default=0, description="Current action number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
totalTasks: int = Field(default=0, description="Total number of tasks in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
totalActions: int = Field(default=0, description="Total number of actions in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
lastActivity: float = Field(default_factory=get_utc_timestamp, description="Timestamp of last activity (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
startedAt: float = Field(default_factory=get_utc_timestamp, description="When the workflow started (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", frontend_type="text", frontend_readonly=True, frontend_required=False)
messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
stats: Optional[ChatStat] = Field(None, description="Workflow statistics", frontend_type="text", frontend_readonly=True, frontend_required=False)
tasks: list = Field(default_factory=list, description="List of tasks in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
workflowMode: str = Field(default="Actionplan", description="Workflow mode selector", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
{"value": "Actionplan", "label": {"en": "Action Plan", "fr": "Plan d'actions"}},
{"value": "React", "label": {"en": "React", "fr": "Réactif"}},
])
maxSteps: int = Field(default=5, description="Maximum number of iterations in react mode", frontend_type="integer", frontend_readonly=False, frontend_required=False)
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Primary key",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
mandateId: str = Field(
description="ID of the mandate this workflow belongs to",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
status: str = Field(
description="Current status of the workflow",
frontend_type="select",
frontend_readonly=False,
frontend_required=False,
frontend_options=[
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
{"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
{"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
],
)
name: Optional[str] = Field(
None,
description="Name of the workflow",
frontend_type="text",
frontend_readonly=False,
frontend_required=True,
)
currentRound: int = Field(
description="Current round number",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
currentTask: int = Field(
default=0,
description="Current task number",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
currentAction: int = Field(
default=0,
description="Current action number",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
totalTasks: int = Field(
default=0,
description="Total number of tasks in the workflow",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
totalActions: int = Field(
default=0,
description="Total number of actions in the workflow",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
lastActivity: float = Field(
default_factory=get_utc_timestamp,
description="Timestamp of last activity (UTC timestamp in seconds)",
frontend_type="timestamp",
frontend_readonly=True,
frontend_required=False,
)
startedAt: float = Field(
default_factory=get_utc_timestamp,
description="When the workflow started (UTC timestamp in seconds)",
frontend_type="timestamp",
frontend_readonly=True,
frontend_required=False,
)
logs: List[ChatLog] = Field(
default_factory=list,
description="Workflow logs",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
messages: List[ChatMessage] = Field(
default_factory=list,
description="Messages in the workflow",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
stats: Optional[ChatStat] = Field(
None,
description="Workflow statistics",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
tasks: list = Field(
default_factory=list,
description="List of tasks in the workflow",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
workflowMode: str = Field(
default="Actionplan",
description="Workflow mode selector",
frontend_type="select",
frontend_readonly=False,
frontend_required=False,
frontend_options=[
{
"value": "Actionplan",
"label": {"en": "Action Plan", "fr": "Plan d'actions"},
},
{"value": "React", "label": {"en": "React", "fr": "Réactif"}},
],
)
maxSteps: int = Field(
default=5,
description="Maximum number of iterations in react mode",
frontend_type="integer",
frontend_readonly=False,
frontend_required=False,
)
register_model_labels(
"ChatWorkflow",
{"en": "Chat Workflow", "fr": "Flux de travail de chat"},
@ -238,12 +423,16 @@ register_model_labels(
},
)
class UserInputRequest(BaseModel, ModelMixin):
prompt: str = Field(description="Prompt for the user")
listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
userLanguage: str = Field(default="en", description="User's preferred language")
register_model_labels(
"UserInputRequest", {"en": "User Input Request", "fr": "Demande de saisie utilisateur"},
"UserInputRequest",
{"en": "User Input Request", "fr": "Demande de saisie utilisateur"},
{
"prompt": {"en": "Prompt", "fr": "Invite"},
"listFileId": {"en": "File IDs", "fr": "IDs des fichiers"},
@ -251,11 +440,15 @@ register_model_labels(
},
)
class ActionDocument(BaseModel, ModelMixin):
"""Clear document structure for action results"""
documentName: str = Field(description="Name of the document")
documentData: Any = Field(description="Content/data of the document")
mimeType: str = Field(description="MIME type of the document")
register_model_labels(
"ActionDocument",
{"en": "Action Document", "fr": "Document d'action"},
@ -266,6 +459,7 @@ register_model_labels(
},
)
class ActionResult(BaseModel, ModelMixin):
"""Clean action result with documents as primary output
@ -276,16 +470,25 @@ class ActionResult(BaseModel, ModelMixin):
success: bool = Field(description="Whether execution succeeded")
error: Optional[str] = Field(None, description="Error message if failed")
documents: List[ActionDocument] = Field(default_factory=list, description="Document outputs")
resultLabel: Optional[str] = Field(None, description="Label for document routing (set by action handler, not by action methods)")
documents: List[ActionDocument] = Field(
default_factory=list, description="Document outputs"
)
resultLabel: Optional[str] = Field(
None,
description="Label for document routing (set by action handler, not by action methods)",
)
@classmethod
def isSuccess(cls, documents: List[ActionDocument] = None) -> "ActionResult":
return cls(success=True, documents=documents or [])
@classmethod
def isFailure(cls, error: str, documents: List[ActionDocument] = None) -> "ActionResult":
def isFailure(
cls, error: str, documents: List[ActionDocument] = None
) -> "ActionResult":
return cls(success=False, documents=documents or [], error=error)
register_model_labels(
"ActionResult",
{"en": "Action Result", "fr": "Résultat de l'action"},
@ -297,9 +500,14 @@ register_model_labels(
},
)
class ActionSelection(BaseModel, ModelMixin):
method: str = Field(description="Method to execute (e.g., web, document, ai)")
name: str = Field(description="Action name within the method (e.g., search, extract)")
name: str = Field(
description="Action name within the method (e.g., search, extract)"
)
register_model_labels(
"ActionSelection",
{"en": "Action Selection", "fr": "Sélection d'action"},
@ -309,8 +517,13 @@ register_model_labels(
},
)
class ActionParameters(BaseModel, ModelMixin):
parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to execute the selected action")
parameters: Dict[str, Any] = Field(
default_factory=dict, description="Parameters to execute the selected action"
)
register_model_labels(
"ActionParameters",
{"en": "Action Parameters", "fr": "Paramètres d'action"},
@ -319,10 +532,13 @@ register_model_labels(
},
)
class ObservationPreview(BaseModel, ModelMixin):
    """Compact preview of a single produced output (embedded in Observation.previews)."""
    name: str = Field(description="Document name or URL label")
    mime: str = Field(description="MIME type or kind")
    snippet: str = Field(description="Short snippet or summary")
register_model_labels(
"ObservationPreview",
{"en": "Observation Preview", "fr": "Aperçu d'observation"},
@ -333,12 +549,19 @@ register_model_labels(
},
)
class Observation(BaseModel, ModelMixin):
success: bool = Field(description="Action execution success flag")
resultLabel: str = Field(description="Deterministic label for produced documents")
documentsCount: int = Field(description="Number of produced documents")
previews: List[ObservationPreview] = Field(default_factory=list, description="Compact previews of outputs")
notes: List[str] = Field(default_factory=list, description="Short notes or key facts")
previews: List[ObservationPreview] = Field(
default_factory=list, description="Compact previews of outputs"
)
notes: List[str] = Field(
default_factory=list, description="Short notes or key facts"
)
register_model_labels(
"Observation",
{"en": "Observation", "fr": "Observation"},
@ -351,12 +574,15 @@ register_model_labels(
},
)
class TaskStatus(str):
class TaskStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
register_model_labels(
"TaskStatus",
{"en": "Task Status", "fr": "Statut de la tâche"},
@ -369,9 +595,14 @@ register_model_labels(
},
)
class DocumentExchange(BaseModel, ModelMixin):
documentsLabel: str = Field(description="Label for the set of documents")
documents: List[str] = Field(default_factory=list, description="List of document references")
documents: List[str] = Field(
default_factory=list, description="List of document references"
)
register_model_labels(
"DocumentExchange",
{"en": "Document Exchange", "fr": "Échange de documents"},
@ -381,20 +612,33 @@ register_model_labels(
},
)
class ActionItem(BaseModel, ModelMixin):
id: str = Field(..., description="Action ID")
execMethod: str = Field(..., description="Method to execute")
execAction: str = Field(..., description="Action to perform")
execParameters: Dict[str, Any] = Field(default_factory=dict, description="Action parameters")
execResultLabel: Optional[str] = Field(None, description="Label for the set of result documents")
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(None, description="Expected document formats (optional)")
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
execParameters: Dict[str, Any] = Field(
default_factory=dict, description="Action parameters"
)
execResultLabel: Optional[str] = Field(
None, description="Label for the set of result documents"
)
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(
None, description="Expected document formats (optional)"
)
userMessage: Optional[str] = Field(
None, description="User-friendly message in user's language"
)
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Action status")
error: Optional[str] = Field(None, description="Error message if action failed")
retryCount: int = Field(default=0, description="Number of retries attempted")
retryMax: int = Field(default=3, description="Maximum number of retries")
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
timestamp: float = Field(..., description="When the action was executed (UTC timestamp in seconds)")
processingTime: Optional[float] = Field(
None, description="Processing time in seconds"
)
timestamp: float = Field(
..., description="When the action was executed (UTC timestamp in seconds)"
)
result: Optional[str] = Field(None, description="Result of the action")
def setSuccess(self, result: str = None) -> None:
@ -408,6 +652,8 @@ class ActionItem(BaseModel, ModelMixin):
"""Set the action as failed with error message"""
self.status = TaskStatus.FAILED
self.error = error_message
register_model_labels(
"ActionItem",
{"en": "Task Action", "fr": "Action de tâche"},
@ -417,7 +663,10 @@ register_model_labels(
"execAction": {"en": "Action", "fr": "Action"},
"execParameters": {"en": "Parameters", "fr": "Paramètres"},
"execResultLabel": {"en": "Result Label", "fr": "Label du résultat"},
"expectedDocumentFormats": {"en": "Expected Document Formats", "fr": "Formats de documents attendus"},
"expectedDocumentFormats": {
"en": "Expected Document Formats",
"fr": "Formats de documents attendus",
},
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
"status": {"en": "Status", "fr": "Statut"},
"error": {"en": "Error", "fr": "Erreur"},
@ -429,12 +678,15 @@ register_model_labels(
},
)
class TaskResult(BaseModel, ModelMixin):
    """Outcome summary for a single task: status, success flag, and optional messages."""
    taskId: str = Field(..., description="Task ID")
    status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
    # NOTE(review): overlaps with status — presumably kept as an explicit quick-check flag
    success: bool = Field(..., description="Whether the task was successful")
    feedback: Optional[str] = Field(None, description="Task feedback message")
    error: Optional[str] = Field(None, description="Error message if task failed")
register_model_labels(
"TaskResult",
{"en": "Task Result", "fr": "Résultat de tâche"},
@ -447,22 +699,39 @@ register_model_labels(
},
)
class TaskItem(BaseModel, ModelMixin):
id: str = Field(..., description="Task ID")
workflowId: str = Field(..., description="Workflow ID")
userInput: str = Field(..., description="User input that triggered the task")
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
error: Optional[str] = Field(None, description="Error message if task failed")
startedAt: Optional[float] = Field(None, description="When the task started (UTC timestamp in seconds)")
finishedAt: Optional[float] = Field(None, description="When the task finished (UTC timestamp in seconds)")
actionList: List[ActionItem] = Field(default_factory=list, description="List of actions to execute")
startedAt: Optional[float] = Field(
None, description="When the task started (UTC timestamp in seconds)"
)
finishedAt: Optional[float] = Field(
None, description="When the task finished (UTC timestamp in seconds)"
)
actionList: List[ActionItem] = Field(
default_factory=list, description="List of actions to execute"
)
retryCount: int = Field(default=0, description="Number of retries attempted")
retryMax: int = Field(default=3, description="Maximum number of retries")
rollbackOnFailure: bool = Field(default=True, description="Whether to rollback on failure")
dependencies: List[str] = Field(default_factory=list, description="List of task IDs this task depends on")
rollbackOnFailure: bool = Field(
default=True, description="Whether to rollback on failure"
)
dependencies: List[str] = Field(
default_factory=list, description="List of task IDs this task depends on"
)
feedback: Optional[str] = Field(None, description="Task feedback message")
processingTime: Optional[float] = Field(None, description="Total processing time in seconds")
resultLabels: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Map of result labels to their values")
processingTime: Optional[float] = Field(
None, description="Total processing time in seconds"
)
resultLabels: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Map of result labels to their values"
)
register_model_labels(
"TaskItem",
{"en": "Task", "fr": "Tâche"},
@ -481,13 +750,18 @@ register_model_labels(
},
)
class TaskStep(BaseModel, ModelMixin):
id: str
objective: str
dependencies: Optional[list[str]] = Field(default_factory=list)
success_criteria: Optional[list[str]] = Field(default_factory=list)
estimated_complexity: Optional[str] = None
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
userMessage: Optional[str] = Field(
None, description="User-friendly message in user's language"
)
register_model_labels(
"TaskStep",
{"en": "Task Step", "fr": "Étape de tâche"},
@ -496,23 +770,45 @@ register_model_labels(
"objective": {"en": "Objective", "fr": "Objectif"},
"dependencies": {"en": "Dependencies", "fr": "Dépendances"},
"success_criteria": {"en": "Success Criteria", "fr": "Critères de succès"},
"estimated_complexity": {"en": "Estimated Complexity", "fr": "Complexité estimée"},
"estimated_complexity": {
"en": "Estimated Complexity",
"fr": "Complexité estimée",
},
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
},
)
class TaskHandover(BaseModel, ModelMixin):
taskId: str = Field(description="Target task ID")
sourceTask: Optional[str] = Field(None, description="Source task ID")
inputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Available input documents")
outputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Produced output documents")
inputDocuments: List[DocumentExchange] = Field(
default_factory=list, description="Available input documents"
)
outputDocuments: List[DocumentExchange] = Field(
default_factory=list, description="Produced output documents"
)
context: Dict[str, Any] = Field(default_factory=dict, description="Task context")
previousResults: List[str] = Field(default_factory=list, description="Previous result summaries")
improvements: List[str] = Field(default_factory=list, description="Improvement suggestions")
workflowSummary: Optional[str] = Field(None, description="Summarized workflow context")
messageHistory: List[str] = Field(default_factory=list, description="Key message summaries")
timestamp: float = Field(..., description="When the handover was created (UTC timestamp in seconds)")
handoverType: str = Field(default="task", description="Type of handover: task, phase, or workflow")
previousResults: List[str] = Field(
default_factory=list, description="Previous result summaries"
)
improvements: List[str] = Field(
default_factory=list, description="Improvement suggestions"
)
workflowSummary: Optional[str] = Field(
None, description="Summarized workflow context"
)
messageHistory: List[str] = Field(
default_factory=list, description="Key message summaries"
)
timestamp: float = Field(
..., description="When the handover was created (UTC timestamp in seconds)"
)
handoverType: str = Field(
default="task", description="Type of handover: task, phase, or workflow"
)
register_model_labels(
"TaskHandover",
{"en": "Task Handover", "fr": "Transfert de tâche"},
@ -531,9 +827,10 @@ register_model_labels(
},
)
class TaskContext(BaseModel, ModelMixin):
task_step: TaskStep
workflow: Optional['ChatWorkflow'] = None
workflow: Optional["ChatWorkflow"] = None
workflow_id: Optional[str] = None
available_documents: Optional[str] = "No documents available"
available_connections: Optional[list[str]] = Field(default_factory=list)
@ -562,6 +859,7 @@ class TaskContext(BaseModel, ModelMixin):
self.improvements = []
self.improvements.append(improvement)
class ReviewContext(BaseModel, ModelMixin):
task_step: TaskStep
task_actions: Optional[list] = Field(default_factory=list)
@ -570,6 +868,7 @@ class ReviewContext(BaseModel, ModelMixin):
workflow_id: Optional[str] = None
previous_results: Optional[list[str]] = Field(default_factory=list)
class ReviewResult(BaseModel, ModelMixin):
status: str
reason: Optional[str] = None
@ -579,7 +878,11 @@ class ReviewResult(BaseModel, ModelMixin):
met_criteria: Optional[list[str]] = Field(default_factory=list)
unmet_criteria: Optional[list[str]] = Field(default_factory=list)
confidence: Optional[float] = 0.5
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
userMessage: Optional[str] = Field(
None, description="User-friendly message in user's language"
)
register_model_labels(
"ReviewResult",
{"en": "Review Result", "fr": "Résultat de l'évaluation"},
@ -596,10 +899,15 @@ register_model_labels(
},
)
class TaskPlan(BaseModel, ModelMixin):
overview: str
tasks: list[TaskStep]
userMessage: Optional[str] = Field(None, description="Overall user-friendly message for the task plan")
userMessage: Optional[str] = Field(
None, description="Overall user-friendly message for the task plan"
)
register_model_labels(
"TaskPlan",
{"en": "Task Plan", "fr": "Plan de tâches"},
@ -613,10 +921,16 @@ register_model_labels(
# Resolve forward references
TaskContext.update_forward_refs()
class PromptPlaceholder(BaseModel, ModelMixin):
label: str
content: str
summaryAllowed: bool = Field(default=False, description="Whether host may summarize content before sending to AI")
summaryAllowed: bool = Field(
default=False,
description="Whether host may summarize content before sending to AI",
)
register_model_labels(
"PromptPlaceholder",
{"en": "Prompt Placeholder", "fr": "Espace réservé d'invite"},
@ -627,11 +941,15 @@ register_model_labels(
},
)
class PromptBundle(BaseModel, ModelMixin):
    """A prompt string together with its associated placeholder payloads."""
    prompt: str
    # Placeholder contents referenced by the prompt; empty when the prompt is self-contained
    placeholders: List[PromptPlaceholder] = Field(default_factory=list)
register_model_labels(
"PromptBundle", {"en": "Prompt Bundle", "fr": "Lot d'invite"},
"PromptBundle",
{"en": "Prompt Bundle", "fr": "Lot d'invite"},
{
"prompt": {"en": "Prompt", "fr": "Invite"},
"placeholders": {"en": "Placeholders", "fr": "Espaces réservés"},

View file

@ -0,0 +1,216 @@
"""Chatbot API models for request/response handling."""
from typing import List, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import register_model_labels, ModelMixin
# Chatbot API Models
class MessageItem(BaseModel, ModelMixin):
    """Individual message in a thread"""
    # Free-form role string; per the description, expected values are "user" or "assistant"
    role: str = Field(..., description="Message role (user or assistant)")
    content: str = Field(..., description="Message content")
class ChatMessageRequest(BaseModel, ModelMixin):
    """Request model for posting a chat message"""
    # Omit to have the server create a new thread (see field description)
    thread_id: Optional[str] = Field(
        None, description="Thread ID (creates new thread if not provided)"
    )
    message: str = Field(..., description="User message content")
    # None means no restriction: all of the user's tools become eligible
    tools: Optional[List[str]] = Field(
        None,
        description="List of tool IDs to use. If not provided, all user's tools will be used",
    )
class ChatMessageResponse(BaseModel, ModelMixin):
    """Response model for posting a chat message"""
    thread_id: str = Field(..., description="Thread ID")
    # Full conversation history, not just the newly generated reply
    messages: List[MessageItem] = Field(..., description="All messages in thread")
class ThreadSummary(BaseModel, ModelMixin):
    """Summary of a chat thread for list view"""
    thread_id: str = Field(..., description="Thread ID")
    thread_name: str = Field(..., description="Thread name")
    # Float timestamps — presumably UNIX epoch seconds; confirm with the producer
    date_created: float = Field(..., description="Thread creation timestamp")
    date_updated: float = Field(..., description="Thread last updated timestamp")
class ThreadListResponse(BaseModel, ModelMixin):
    """Response model for listing all threads"""
    # Summaries only — fetch ThreadDetail for a thread's messages
    threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
class ThreadDetail(BaseModel, ModelMixin):
    """Detailed view of a single thread"""
    thread_id: str = Field(..., description="Thread ID")
    date_created: float = Field(..., description="Thread creation timestamp")
    date_updated: float = Field(..., description="Thread last updated timestamp")
    # Ordered oldest-first per the field description
    messages: List[MessageItem] = Field(
        ..., description="All messages in chronological order"
    )
class RenameThreadRequest(BaseModel, ModelMixin):
    """Request model for renaming a thread"""
    # Body carries only the name; presumably the thread ID travels in the URL path — confirm against the route
    new_name: str = Field(..., description="New name for the thread")
class DeleteResponse(BaseModel, ModelMixin):
    """Response model for thread delete operations"""
    message: str = Field(..., description="Confirmation message")
    # Echoes the removed thread's ID back to the caller
    thread_id: str = Field(..., description="Deleted thread ID")
# Tool Management Models
class ToolInfo(BaseModel, ModelMixin):
    """Information about a chatbot tool"""
    # UUID of the tool record — distinct from the dotted tool_id identifier below
    id: str = Field(..., description="Tool UUID")
    tool_id: str = Field(
        ..., description="Tool identifier (e.g., 'shared.tavily_search')"
    )
    name: str = Field(..., description="Tool function name")
    label: str = Field(..., description="Display label for the tool")
    category: str = Field(..., description="Tool category (shared or customer)")
    description: str = Field(..., description="Tool description")
    is_active: bool = Field(..., description="Whether the tool is active")
    date_created: float = Field(..., description="Creation timestamp")
    date_updated: float = Field(..., description="Last update timestamp")
class ToolListResponse(BaseModel, ModelMixin):
    """Response model for listing all tools"""
    # Each entry is a full ToolInfo record
    tools: List[ToolInfo] = Field(..., description="List of available tools")
class GrantToolRequest(BaseModel, ModelMixin):
    """Request model for granting a tool to a user"""
    user_id: str = Field(..., description="User ID to grant the tool to")
    # UUID (ToolInfo.id), not the dotted tool identifier
    tool_id: str = Field(..., description="Tool UUID from tools table")
class GrantToolResponse(BaseModel, ModelMixin):
    """Response model after granting a tool"""
    message: str = Field(..., description="Confirmation message")
    # Echo of the grant that was applied
    user_id: str = Field(..., description="User ID")
    tool_id: str = Field(..., description="Tool UUID")
class RevokeToolRequest(BaseModel, ModelMixin):
    """Request model for revoking a tool from a user"""
    user_id: str = Field(..., description="User ID to revoke the tool from")
    # UUID (ToolInfo.id), not the dotted tool identifier
    tool_id: str = Field(..., description="Tool UUID from tools table")
class RevokeToolResponse(BaseModel, ModelMixin):
    """Response model after revoking a tool"""
    message: str = Field(..., description="Confirmation message")
    # Echo of the revocation that was applied
    user_id: str = Field(..., description="User ID")
    tool_id: str = Field(..., description="Tool UUID")
class UpdateToolRequest(BaseModel, ModelMixin):
    """Request model for updating a tool's label and description"""
    # Both fields optional: omitted fields are left unchanged (partial update)
    label: Optional[str] = Field(None, description="New label for the tool")
    description: Optional[str] = Field(None, description="New description for the tool")
class UpdateToolResponse(BaseModel, ModelMixin):
    """Response model after updating a tool"""
    message: str = Field(..., description="Confirmation message")
    tool_id: str = Field(..., description="Tool UUID")
    # Names of the fields that were actually changed by the request
    updated_fields: List[str] = Field(..., description="List of updated field names")
# Register model labels for internationalization
# Arguments: model name, model-level labels, then per-field labels (en/fr)
register_model_labels(
    "MessageItem",
    {"en": "Message Item", "fr": "Élément de message"},
    {
        "role": {"en": "Role", "fr": "Rôle"},
        "content": {"en": "Content", "fr": "Contenu"},
    },
)
# Per-field labels for ChatMessageRequest (en/fr)
register_model_labels(
    "ChatMessageRequest",
    {"en": "Chat Message Request", "fr": "Demande de message de chat"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "message": {"en": "Message", "fr": "Message"},
        # "tools" is declared on ChatMessageRequest but was missing from the
        # label map; register it so frontends get a localized label for it.
        "tools": {"en": "Tools", "fr": "Outils"},
    },
)
# Label registrations for the thread-related models above (en/fr).
# NOTE(review): ToolInfo / Grant* / Revoke* / Update* models do not appear to be
# registered in this file — confirm whether that is intentional.
register_model_labels(
    "ChatMessageResponse",
    {"en": "Chat Message Response", "fr": "Réponse du message de chat"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "messages": {"en": "Messages", "fr": "Messages"},
    },
)
register_model_labels(
    "ThreadSummary",
    {"en": "Thread Summary", "fr": "Résumé du fil"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "thread_name": {"en": "Thread Name", "fr": "Nom du fil"},
        "date_created": {"en": "Date Created", "fr": "Date de création"},
        "date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
    },
)
register_model_labels(
    "ThreadListResponse",
    {"en": "Thread List Response", "fr": "Réponse de liste de fils"},
    {
        "threads": {"en": "Threads", "fr": "Fils"},
    },
)
register_model_labels(
    "ThreadDetail",
    {"en": "Thread Detail", "fr": "Détail du fil"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "date_created": {"en": "Date Created", "fr": "Date de création"},
        "date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
        "messages": {"en": "Messages", "fr": "Messages"},
    },
)
register_model_labels(
    "RenameThreadRequest",
    {"en": "Rename Thread Request", "fr": "Demande de renommage de fil"},
    {
        "new_name": {"en": "New Name", "fr": "Nouveau nom"},
    },
)
register_model_labels(
    "DeleteResponse",
    {"en": "Delete Response", "fr": "Réponse de suppression"},
    {
        "message": {"en": "Message", "fr": "Message"},
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
    },
)

View file

@ -1,7 +1,7 @@
"""Security models: Token and AuthEvent."""
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from modules.shared.attributeUtils import register_model_labels, ModelMixin
from modules.shared.timezoneUtils import get_utc_timestamp
from .datamodelUam import AuthAuthority
@ -13,25 +13,44 @@ class TokenStatus(str, Enum):
ACTIVE = "active"
REVOKED = "revoked"
class Token(BaseModel, ModelMixin):
id: Optional[str] = None
userId: str
authority: AuthAuthority
connectionId: Optional[str] = Field(None, description="ID of the connection this token belongs to")
connectionId: Optional[str] = Field(
None, description="ID of the connection this token belongs to"
)
tokenAccess: str
tokenType: str = "bearer"
expiresAt: float = Field(description="When the token expires (UTC timestamp in seconds)")
expiresAt: float = Field(
description="When the token expires (UTC timestamp in seconds)"
)
tokenRefresh: Optional[str] = None
createdAt: Optional[float] = Field(None, description="When the token was created (UTC timestamp in seconds)")
status: TokenStatus = Field(default=TokenStatus.ACTIVE, description="Token status: active/revoked")
revokedAt: Optional[float] = Field(None, description="When the token was revoked (UTC timestamp in seconds)")
revokedBy: Optional[str] = Field(None, description="User ID who revoked the token (admin/self)")
createdAt: Optional[float] = Field(
None, description="When the token was created (UTC timestamp in seconds)"
)
status: TokenStatus = Field(
default=TokenStatus.ACTIVE, description="Token status: active/revoked"
)
revokedAt: Optional[float] = Field(
None, description="When the token was revoked (UTC timestamp in seconds)"
)
revokedBy: Optional[str] = Field(
None, description="User ID who revoked the token (admin/self)"
)
reason: Optional[str] = Field(None, description="Optional revocation reason")
sessionId: Optional[str] = Field(None, description="Logical session grouping for logout revocation")
mandateId: Optional[str] = Field(None, description="Mandate ID for tenant scoping of the token")
sessionId: Optional[str] = Field(
None, description="Logical session grouping for logout revocation"
)
mandateId: Optional[str] = Field(
None, description="Mandate ID for tenant scoping of the token"
)
class Config:
use_enum_values = True
register_model_labels(
"Token",
{"en": "Token", "fr": "Jeton"},
@ -54,15 +73,64 @@ register_model_labels(
},
)
class AuthEvent(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the auth event", frontend_type="text", frontend_readonly=True, frontend_required=False)
userId: str = Field(description="ID of the user this event belongs to", frontend_type="text", frontend_readonly=True, frontend_required=True)
eventType: str = Field(description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')", frontend_type="text", frontend_readonly=True, frontend_required=True)
timestamp: float = Field(default_factory=get_utc_timestamp, description="Unix timestamp when the event occurred", frontend_type="datetime", frontend_readonly=True, frontend_required=True)
ipAddress: Optional[str] = Field(default=None, description="IP address from which the event originated", frontend_type="text", frontend_readonly=True, frontend_required=False)
userAgent: Optional[str] = Field(default=None, description="User agent string from the request", frontend_type="text", frontend_readonly=True, frontend_required=False)
success: bool = Field(default=True, description="Whether the authentication event was successful", frontend_type="boolean", frontend_readonly=True, frontend_required=True)
details: Optional[str] = Field(default=None, description="Additional details about the event", frontend_type="text", frontend_readonly=True, frontend_required=False)
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID of the auth event",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
userId: str = Field(
description="ID of the user this event belongs to",
frontend_type="text",
frontend_readonly=True,
frontend_required=True,
)
eventType: str = Field(
description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')",
frontend_type="text",
frontend_readonly=True,
frontend_required=True,
)
timestamp: float = Field(
default_factory=get_utc_timestamp,
description="Unix timestamp when the event occurred",
frontend_type="datetime",
frontend_readonly=True,
frontend_required=True,
)
ipAddress: Optional[str] = Field(
default=None,
description="IP address from which the event originated",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
userAgent: Optional[str] = Field(
default=None,
description="User agent string from the request",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
success: bool = Field(
default=True,
description="Whether the authentication event was successful",
frontend_type="boolean",
frontend_readonly=True,
frontend_required=True,
)
details: Optional[str] = Field(
default=None,
description="Additional details about the event",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
register_model_labels(
"AuthEvent",
{"en": "Authentication Event", "fr": "Événement d'authentification"},
@ -77,5 +145,3 @@ register_model_labels(
"details": {"en": "Details", "fr": "Détails"},
},
)

View file

@ -0,0 +1 @@
"""Contains all tools available for the chatbot to use."""

View file

@ -0,0 +1 @@
"""Tools that are shared between multiple customers go here."""

View file

@ -0,0 +1,208 @@
"""Althaus Database Query Tool for LangGraph.
This tool provides database query capabilities for the Althaus database
via an external REST API. Only SELECT queries are allowed.
"""
import logging
import asyncio
import re
from typing import Annotated
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
async def _mock_api_call(*, sql_query: str) -> dict:
"""Mock the external REST API call to Althaus database.
Args:
sql_query: The SQL SELECT query to execute
Returns:
A dictionary containing the query results with columns and rows
"""
# Simulate network delay
await asyncio.sleep(0.5)
# Mock response data based on common query patterns
if "users" in sql_query.lower():
return {
"columns": ["id", "username", "email", "created_at"],
"rows": [
[1, "john_doe", "john@example.com", "2024-01-15"],
[2, "jane_smith", "jane@example.com", "2024-02-20"],
[3, "bob_wilson", "bob@example.com", "2024-03-10"],
],
"row_count": 3,
}
elif "products" in sql_query.lower():
return {
"columns": ["product_id", "name", "price", "stock"],
"rows": [
[101, "Widget A", 29.99, 150],
[102, "Widget B", 39.99, 75],
[103, "Widget C", 19.99, 200],
],
"row_count": 3,
}
elif "orders" in sql_query.lower():
return {
"columns": ["order_id", "customer_id", "total", "status"],
"rows": [
[5001, 1, 129.99, "completed"],
[5002, 2, 89.50, "pending"],
[5003, 1, 199.99, "shipped"],
],
"row_count": 3,
}
else:
# Generic response for other queries
return {
"columns": ["id", "value", "description"],
"rows": [
[1, "Sample 1", "First sample entry"],
[2, "Sample 2", "Second sample entry"],
],
"row_count": 2,
}
def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
"""Validate that the query is a SELECT statement only.
Args:
sql_query: The SQL query to validate
Returns:
A tuple of (is_valid, error_message)
"""
# Remove leading/trailing whitespace and convert to lowercase for checking
normalized_query = sql_query.strip().lower()
# Check if query starts with SELECT
if not normalized_query.startswith("select"):
return False, "Query must be a SELECT statement"
# Check for dangerous keywords that should not be in a SELECT query
dangerous_keywords = [
"insert",
"update",
"delete",
"drop",
"create",
"alter",
"truncate",
"grant",
"revoke",
"exec",
"execute",
]
for keyword in dangerous_keywords:
# Use word boundary to match whole words only
if re.search(rf"\b{keyword}\b", normalized_query):
return False, f"Query contains forbidden keyword: {keyword.upper()}"
return True, ""
def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
"""Format query results into a readable string.
Args:
columns: List of column names
rows: List of row data
row_count: Total number of rows
Returns:
Formatted string representation of the results
"""
if row_count == 0:
return "Query executed successfully but returned no results."
# Calculate column widths
col_widths = [len(str(col)) for col in columns]
for row in rows:
for i, cell in enumerate(row):
col_widths[i] = max(col_widths[i], len(str(cell)))
# Build header
header_parts = []
for col, width in zip(columns, col_widths):
header_parts.append(str(col).ljust(width))
header = " | ".join(header_parts)
separator = "-" * len(header)
# Build rows
row_lines = []
for row in rows:
row_parts = []
for cell, width in zip(row, col_widths):
row_parts.append(str(cell).ljust(width))
row_lines.append(" | ".join(row_parts))
# Combine all parts
result_parts = [
f"Query returned {row_count} row(s):\n",
header,
separator,
"\n".join(row_lines),
]
return "\n".join(result_parts)
@tool
async def query_althaus_database(
    sql_query: Annotated[
        str, "The SQL SELECT query to execute against the Althaus database"
    ],
) -> str:
    """Execute a SELECT query against the Althaus database via REST API.

    Use this tool to query data from the Althaus database. Only SELECT statements
    are allowed for security reasons. The query will be forwarded to an external
    REST API and the results will be returned in a formatted table.

    Args:
        sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")

    Returns:
        A formatted string containing the query results with columns and rows
    """
    try:
        # Guard clause: refuse anything that is not a clean SELECT statement.
        valid, validation_error = _validate_select_query(sql_query=sql_query)
        if not valid:
            logger.warning(f"Invalid query attempt: {sql_query[:100]}...")
            return f"Error: {validation_error}"

        logger.info(f"Executing Althaus database query: {sql_query[:100]}...")

        # Mock the external REST API call
        # In production, this would be replaced with actual REST API call:
        # response = await httpx.AsyncClient().post(
        #     "https://api.althaus.example.com/query",
        #     json={"query": sql_query},
        #     headers={"Authorization": f"Bearer {api_key}"}
        # )
        # result = response.json()
        api_response = await _mock_api_call(sql_query=sql_query)

        # Render the raw column/row payload as an aligned text table.
        table = _format_results(
            columns=api_response["columns"],
            rows=api_response["rows"],
            row_count=api_response["row_count"],
        )
        logger.info(
            f"Query completed successfully, returned {api_response['row_count']} row(s)"
        )
        return table
    except Exception as e:
        # Tool contract: never raise — always hand the model a readable error string.
        logger.error(f"Error in query_althaus_database tool: {str(e)}")
        return f"Error executing query: {str(e)}"

View file

@ -0,0 +1,362 @@
"""Power BI Query Tool for LangGraph.
This tool provides DAX query capabilities for Power BI datasets
via the Power BI REST API. Only read-only queries are allowed.
"""
import logging
import asyncio
import os
import re
import functools
from typing import Annotated
import anyio
import httpx
from langchain_core.tools import tool
from msal import ConfidentialClientApplication, SerializableTokenCache
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)

# Configuration constants - encapsulated in this file
# All four connection settings are read once at import time from APP_CONFIG;
# an empty string means "not configured" and is reported by
# _validate_environment() before any query is attempted.
POWERBI_DATASET_ID = APP_CONFIG.get("VALUEON_POWERBI_DATASET_ID", "")
POWERBI_CLIENT_ID = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_ID", "")
POWERBI_CLIENT_SECRET = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_SECRET", "")
POWERBI_TENANT_ID = APP_CONFIG.get("VALUEON_POWERBI_TENANT_ID", "")

# Fixed Power BI / Azure AD endpoints and OAuth scope.
POWERBI_BASE_URL = "https://api.powerbi.com/v1.0/myorg"
POWERBI_AUTHORITY_BASE = "https://login.microsoftonline.com"
POWERBI_SCOPE = ["https://analysis.windows.net/powerbi/api/.default"]

# Limit results to prevent excessive context usage
MAX_ROWS_LIMIT = 100
def _validate_environment() -> tuple[bool, str]:
    """Validate that all required configuration values are set.

    Returns:
        A tuple of (is_valid, error_message)
    """
    # Map each required module constant to the APP_CONFIG key it is loaded
    # from, so the error message names the exact key an operator must set.
    # (The previous message reported "POWERBI_DATASET_ID" etc., but the
    # actual configuration keys carry the VALUEON_ prefix.)
    required = {
        "VALUEON_POWERBI_DATASET_ID": POWERBI_DATASET_ID,
        "VALUEON_POWERBI_CLIENT_ID": POWERBI_CLIENT_ID,
        "VALUEON_POWERBI_CLIENT_SECRET": POWERBI_CLIENT_SECRET,
        "VALUEON_POWERBI_TENANT_ID": POWERBI_TENANT_ID,
    }
    missing = [key for key, value in required.items() if not value]
    if missing:
        return False, f"Missing required environment variables: {', '.join(missing)}"
    return True, ""
def _validate_dax_query(*, dax_query: str) -> tuple[bool, str]:
"""Validate that the query is a valid DAX query.
Args:
dax_query: The DAX query to validate
Returns:
A tuple of (is_valid, error_message)
"""
# Remove leading/trailing whitespace
normalized_query = dax_query.strip()
if not normalized_query:
return False, "Query cannot be empty"
# DAX queries typically start with EVALUATE, DEFINE, or are table expressions
# We'll be lenient and just check it's not trying to do something dangerous
# DAX is read-only by nature, but we validate structure
# Check for minimum length
if len(normalized_query) < 5:
return False, "Query is too short to be valid"
return True, ""
def _get_access_token_sync(
    *,
    tenant_id: str,
    client_id: str,
    client_secret: str,
    authority_base: str = POWERBI_AUTHORITY_BASE,
    cache: SerializableTokenCache | None = None,
) -> str:
    """Get Power BI access token using MSAL (synchronous).

    Args:
        tenant_id: Azure AD tenant ID
        client_id: Application client ID
        client_secret: Application client secret
        authority_base: Azure AD authority base URL
        cache: Optional token cache for reuse

    Returns:
        Access token string

    Raises:
        RuntimeError: If token acquisition fails
    """
    msal_app = ConfidentialClientApplication(
        client_id=client_id,
        authority=f"{authority_base}/{tenant_id}",
        client_credential=client_secret,
        token_cache=cache,
    )

    # Prefer a cached token; otherwise run the client-credentials flow.
    token_result = msal_app.acquire_token_silent(POWERBI_SCOPE, account=None)
    if not token_result:
        token_result = msal_app.acquire_token_for_client(scopes=POWERBI_SCOPE)

    if "access_token" not in token_result:
        raise RuntimeError(
            f"MSAL token error: {token_result.get('error')} - {token_result.get('error_description')}"
        )
    return token_result["access_token"]
async def _get_access_token_async(
    *,
    tenant_id: str,
    client_id: str,
    client_secret: str,
    **kwargs,
) -> str:
    """Get Power BI access token using MSAL (asynchronous).

    Args:
        tenant_id: Azure AD tenant ID
        client_id: Application client ID
        client_secret: Application client secret
        **kwargs: Additional arguments for _get_access_token_sync

    Returns:
        Access token string
    """
    # MSAL performs blocking HTTP I/O, so bind the arguments up front and
    # run the sync helper on a worker thread instead of blocking the loop.
    bound_call = functools.partial(
        _get_access_token_sync,
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret=client_secret,
        **kwargs,
    )
    return await anyio.to_thread.run_sync(bound_call)
async def _execute_dax_query(
    *, dax_query: str, dataset_id: str, access_token: str
) -> dict:
    """Execute a DAX query against Power BI dataset.

    Args:
        dax_query: The DAX query to execute
        dataset_id: Power BI dataset ID
        access_token: Access token for authentication

    Returns:
        Dictionary containing query results

    Raises:
        RuntimeError: If query execution fails
    """
    request_url = f"{POWERBI_BASE_URL}/datasets/{dataset_id}/executeQueries"
    request_headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }
    request_body = {
        "queries": [{"query": dax_query}],
        "serializerSettings": {"includeNulls": True},
    }

    async with httpx.AsyncClient(timeout=60.0) as http_client:
        response = await http_client.post(
            request_url, headers=request_headers, json=request_body
        )

    # The response body is fully read before the client closes, so it is
    # safe to inspect it outside the context manager.
    if response.status_code != 200:
        raise RuntimeError(
            f"Power BI executeQueries failed: {response.status_code} - {response.text}"
        )

    payload = response.json()
    try:
        rows = payload["results"][0]["tables"][0]["rows"]
    except (KeyError, IndexError) as exc:
        raise RuntimeError("Unexpected executeQueries response structure") from exc

    # Column names come from the keys of the first row; no rows means no columns.
    columns = list(rows[0].keys()) if rows else []
    return {"columns": columns, "rows": rows}
def _strip_table_qualifier(*, column_name: str) -> str:
"""Strip table qualifier from column name.
Power BI often returns columns as 'Table[Column]'. This strips to 'Column'.
Args:
column_name: The column name to process
Returns:
Processed column name
"""
if "[" in column_name and column_name.endswith("]"):
return column_name.split("[", 1)[1][:-1]
return column_name
def _format_results(*, columns: list[str], rows: list[dict], max_rows: int) -> str:
    """Format query results into a readable string.

    Args:
        columns: List of column names
        rows: List of row data (as dictionaries)
        max_rows: Maximum number of rows to display

    Returns:
        Formatted string representation of the results
    """
    total_rows = len(rows)
    if total_rows == 0:
        return "Query executed successfully but returned no results."

    # Headers show the bare column names; cell lookups still use the raw keys.
    display_names = [_strip_table_qualifier(column_name=name) for name in columns]

    visible_rows = rows[:max_rows]
    was_truncated = total_rows > max_rows

    # Column width = widest of the header and every visible value.
    widths = [len(str(name)) for name in display_names]
    for record in visible_rows:
        for idx, raw_key in enumerate(columns):
            widths[idx] = max(widths[idx], len(str(record.get(raw_key, ""))))

    header = " | ".join(
        str(name).ljust(width) for name, width in zip(display_names, widths)
    )
    separator = "-" * len(header)
    body_lines = [
        " | ".join(
            str(record.get(raw_key, "")).ljust(width)
            for raw_key, width in zip(columns, widths)
        )
        for record in visible_rows
    ]

    # The second element is either the truncation notice or an empty string,
    # which yields the blank line before the header in the untruncated case.
    pieces = [f"Query returned {total_rows} row(s):"]
    if was_truncated:
        pieces.append(f"(Results limited to {max_rows} rows for context efficiency)\n")
    else:
        pieces.append("")
    pieces.extend([header, separator, "\n".join(body_lines)])
    return "\n".join(pieces)
@tool
async def query_powerbi_data(
    dax_query: Annotated[str, "The DAX query to execute against the Power BI dataset"],
) -> str:
    """Execute a DAX query against the Power BI dataset to access warehouse inventory data.

    This tool provides access to a Power BI table called 'data_full' which contains
    articles available in the warehouse of the user. Use DAX (Data Analysis Expressions)
    queries to retrieve and analyze this inventory data.

    Available table:
    - 'data_full': Contains warehouse inventory articles and their details

    Common query patterns:
    - View all data: EVALUATE 'data_full'
    - With filter: EVALUATE FILTER('data_full', [Column] = "Value")
    - Top N rows: EVALUATE TOPN(10, 'data_full', [Column], DESC)
    - Calculated: EVALUATE SUMMARIZE('data_full', [Column1], "Total", SUM([Column2]))

    Results are limited to 100 rows maximum for efficiency.

    Args:
        dax_query: The DAX query to execute (e.g., "EVALUATE 'data_full'")

    Returns:
        A formatted string containing the query results with columns and rows
    """
    try:
        # Configuration must be complete before we even look at the query.
        env_ok, error_msg = _validate_environment()
        if not env_ok:
            logger.error(f"Environment validation failed: {error_msg}")
            return f"Configuration Error: {error_msg}"

        # Reject structurally invalid DAX early.
        query_ok, error_msg = _validate_dax_query(dax_query=dax_query)
        if not query_ok:
            logger.warning(f"Invalid query attempt: {dax_query[:100]}...")
            return f"Query Validation Error: {error_msg}"

        logger.info(f"Executing Power BI query: {dax_query[:100]}...")

        # Authenticate, run the query, then render the rows as a text table.
        token = await _get_access_token_async(
            tenant_id=POWERBI_TENANT_ID,
            client_id=POWERBI_CLIENT_ID,
            client_secret=POWERBI_CLIENT_SECRET,
        )
        query_result = await _execute_dax_query(
            dax_query=dax_query,
            dataset_id=POWERBI_DATASET_ID,
            access_token=token,
        )
        rendered = _format_results(
            columns=query_result["columns"],
            rows=query_result["rows"],
            max_rows=MAX_ROWS_LIMIT,
        )

        logger.info(
            f"Query completed successfully, returned {len(query_result['rows'])} row(s)"
        )
        return rendered
    except RuntimeError as e:
        logger.error(f"Runtime error in query_powerbi_data tool: {str(e)}")
        return f"Error executing query: {str(e)}"
    except Exception as e:
        logger.error(f"Unexpected error in query_powerbi_data tool: {str(e)}")
        return f"Unexpected error: {str(e)}"

View file

@ -0,0 +1,7 @@
"""Shared tools available across all chatbot implementations."""
from modules.features.chatBot.chatbotTools.sharedTools.toolTavilySearch import (
tavily_search,
)
__all__ = ["tavily_search"]

View file

@ -0,0 +1,24 @@
"""Tool for sending streaming status updates to users."""
from langchain_core.tools import tool
@tool
def send_streaming_message(message: str) -> str:
    """Send a streaming message to the user to provide updates during processing.

    Use this tool to send short status updates to the user while you are working
    on their request. This helps keep the user informed about what you are doing.

    Args:
        message: A short message describing what you are currently doing.
            Examples: "Searching database for relevant information..."
            "Analyzing search results..."
            "Processing your request..."

    Returns:
        A confirmation that the message was sent.
    """
    # This tool doesn't actually do anything - it's just for the AI to signal
    # what it's doing to the frontend via the tool call mechanism
    # (the streaming layer watches for an on_tool_start event named
    # "send_streaming_message" and forwards `message` as a status update).
    return f"Status update sent: {message}"

View file

@ -0,0 +1,55 @@
"""Tavily Search Tool for LangGraph.
This tool provides web search capabilities using the Tavily API.
"""
import logging
from typing import Annotated
from langchain_core.tools import tool
from modules.connectors.connectorAiTavily import ConnectorWeb
logger = logging.getLogger(__name__)
@tool
async def tavily_search(
    query: Annotated[str, "The search query to look up on the web"],
) -> str:
    """Search the web using Tavily API.

    Use this tool to search for current information, news, or any web content.
    The tool returns relevant search results including titles and URLs.

    Args:
        query: The search query string

    Returns:
        A formatted string containing search results with titles and URLs
    """
    try:
        # NOTE(review): this reaches into ConnectorWeb._search, a private
        # method - confirm no public search API exists on the connector.
        connector = await ConnectorWeb.create()
        results = await connector._search(
            query=query,
            max_results=5,
            search_depth="basic",
            include_answer=True,
            include_raw_content=False,
        )

        if not results:
            return f"No results found for query: {query}"

        # Numbered list of hits: title on one line, URL on the next.
        lines = [f"Search results for '{query}':\n"]
        for position, hit in enumerate(results, 1):
            lines.append(f"{position}. {hit.title}")
            lines.append(f" URL: {hit.url}\n")
        return "\n".join(lines)
    except Exception as e:
        logger.error(f"Error in tavily_search tool: {str(e)}")
        return f"Error performing search: {str(e)}"

View file

@ -0,0 +1,197 @@
from typing import AsyncIterator
import uuid
from datetime import datetime, timezone
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint
class Base(DeclarativeBase):
    """Declarative base class shared by all chatbot ORM models in this module."""

    pass
# Tools Table
class Tool(Base):
    """Available chatbot tools.

    Stores information about all available tools that can be assigned to users.
    Each tool has a unique tool_id that corresponds to the registry tool_id.
    """

    __tablename__ = "tools"

    # Surrogate primary key.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # Stable identifier matching the registry's tool_id; unique per tool.
    tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    # Human-readable name synced from the registry.
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Display label; initialized to tool_id on creation (user-editable).
    label: Mapped[str] = mapped_column(String(255), nullable=False)
    # Registry category used for grouping.
    category: Mapped[str] = mapped_column(String(50), nullable=False)
    # Free-text description (user-editable; preserved on registry sync).
    description: Mapped[str] = mapped_column(String(1000), nullable=False)
    # False when the tool has been removed from the registry.
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Creation timestamp (UTC).
    date_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # Last modification timestamp (UTC); updated explicitly by sync code.
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
# User-Tool Mapping Table
class UserToolMapping(Base):
    """Mapping of users to their available tools.

    Many-to-many relationship between users and tools.
    - One user can have multiple tools
    - One tool can be assigned to multiple users

    The combination of user_id and tool_id is unique.
    """

    __tablename__ = "user_tools"
    # Enforce one mapping row per (user, tool) pair.
    __table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),)

    # Surrogate primary key.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # External user identifier (string, not a FK to a local table).
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # References Tool.id (no DB-level foreign key constraint is declared here).
    tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False)
    # Whether the grant is currently active.
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # When the tool was granted to the user (UTC).
    date_granted: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # Last modification timestamp (UTC).
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
# User Thread Mapping Table
class UserThreadMapping(Base):
    """Mapping of users to their chat threads.

    Used to keep track of which user owns which chat thread.
    Also stores meta data like thread name.

    1:N relationship between user and thread. A thread belongs to exactly one user.
    A user can have multiple threads.
    Thread_id is unique in the table.
    """

    __tablename__ = "user_threads"

    # Surrogate primary key.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # Owning user's external identifier.
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # LangGraph thread identifier; unique so a thread has exactly one owner.
    thread_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    # Display name of the thread.
    thread_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Creation timestamp (UTC).
    date_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # Last modification timestamp (UTC).
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
# Dependency that pulls the sessionmaker off app.state
# This is set in app.py on startup in @asynccontextmanager
# TODO: If we use SQLAlchemy in other places, we can move this to a shared module
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields a per-request async SQLAlchemy session.

    Args:
        request: The incoming request; its ``app.state.checkpoint_sessionmaker``
            must have been set during application startup.

    Yields:
        An AsyncSession that is closed automatically when the request ends.
    """
    SessionLocal: async_sessionmaker[AsyncSession] = (
        request.app.state.checkpoint_sessionmaker
    )
    async with SessionLocal() as session:
        yield session
# Optional helper to init tables at startup (demo only)
async def init_models(engine) -> None:
    """Create all tables registered on Base.metadata using the given async engine."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
async def sync_tools_from_registry(session: AsyncSession) -> None:
    """Sync tools from tool registry to database.

    This function:
    - Adds new tools from the registry to the database
    - Updates existing tools with current registry information
    - Marks tools not present in the registry as inactive

    Should be called on application startup after database initialization.

    NOTE(review): changes are only staged on the session here; no commit is
    issued, so the caller is expected to commit (or roll back) afterwards.

    Args:
        session: Active database session
    """
    # Local imports - presumably to avoid import cycles at module load time;
    # TODO confirm whether these can move to the top of the file.
    import logging

    from sqlalchemy import select

    from modules.features.chatBot.utils.toolRegistry import get_registry

    logger = logging.getLogger(__name__)
    logger.info("Syncing tools from registry to database...")

    # Get all tools from the registry
    registry = get_registry()
    registry_tools = registry.get_all_tools()

    # Create a set of tool_ids from the registry
    registry_tool_ids = {tool.tool_id for tool in registry_tools}

    logger.info(f"Found {len(registry_tools)} tools in registry")

    # Get all existing tools from the database
    result = await session.execute(select(Tool))
    db_tools = result.scalars().all()
    db_tools_by_tool_id = {tool.tool_id: tool for tool in db_tools}

    logger.info(f"Found {len(db_tools)} tools in database")

    # Track changes
    added_count = 0
    updated_count = 0
    deactivated_count = 0

    # Sync tools from registry to database
    for registry_tool in registry_tools:
        if registry_tool.tool_id in db_tools_by_tool_id:
            # Tool exists - update it
            # Preserve label and description (user-editable fields)
            db_tool = db_tools_by_tool_id[registry_tool.tool_id]
            db_tool.name = registry_tool.name
            db_tool.category = registry_tool.category
            db_tool.is_active = True
            db_tool.date_updated = datetime.now(timezone.utc)
            updated_count += 1
            logger.debug(f"Updated tool: {registry_tool.tool_id}")
        else:
            # Tool doesn't exist - create it
            new_tool = Tool(
                tool_id=registry_tool.tool_id,
                name=registry_tool.name,
                label=registry_tool.tool_id,  # Use tool_id as label per spec
                category=registry_tool.category,
                description=registry_tool.description or "",
                is_active=True,
            )
            session.add(new_tool)
            added_count += 1
            logger.debug(f"Added new tool: {registry_tool.tool_id}")

    # Mark tools not in registry as inactive
    for db_tool in db_tools:
        if db_tool.tool_id not in registry_tool_ids and db_tool.is_active:
            db_tool.is_active = False
            db_tool.date_updated = datetime.now(timezone.utc)
            deactivated_count += 1
            logger.debug(f"Deactivated tool not in registry: {db_tool.tool_id}")

    logger.info(
        f"Tool sync complete: {added_count} added, {updated_count} updated, {deactivated_count} deactivated"
    )

View file

@ -0,0 +1 @@
"""Domain logic for chatbot functionality."""

View file

@ -0,0 +1,301 @@
"""Chatbot domain logic with LangGraph integration."""
from dataclasses import dataclass
from typing import Annotated, AsyncIterator, Any
import logging
from pydantic import BaseModel
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
trim_messages,
)
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langchain_anthropic import ChatAnthropic
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
from modules.features.chatBot.utils.toolRegistry import get_registry
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
class ChatState(BaseModel):
    """Represents the state of a chat session."""

    # Conversation history; the add_messages reducer makes LangGraph merge
    # newly returned messages into this list instead of replacing it.
    messages: Annotated[list[BaseMessage], add_messages]
def get_langchain_model(*, model_name: str) -> ChatAnthropic:
    """Map permission model names to LangChain ChatAnthropic models.

    Unknown model names do not raise; they fall back to the default
    Anthropic model with a warning.

    Args:
        model_name: The model name from permissions (e.g., "claude_4_5")

    Returns:
        Configured ChatAnthropic instance
    """
    # Model name mapping
    model_mapping = {
        "claude_4_5": "claude-sonnet-4-5",
        # Add more mappings as needed
    }

    # Fall back to the same id used in the mapping. The previous fallback
    # "claude-4-5-sonnet" is not a valid Anthropic model id and would fail
    # at request time.
    default_model = "claude-sonnet-4-5"
    anthropic_model = model_mapping.get(model_name)
    if not anthropic_model:
        logger.warning(
            f"Unknown model name '{model_name}', defaulting to {default_model}"
        )
        anthropic_model = default_model

    return ChatAnthropic(
        model=anthropic_model,
        api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
        temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
        max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
    )
@dataclass
class Chatbot:
    """Represents a chatbot with LangGraph integration."""

    # Chat model instance; must support bind_tools() (see _build_app).
    model: Any
    # LangGraph checkpointer providing per-thread conversation memory.
    memory: Any
    # Compiled LangGraph application; populated by create() via _build_app().
    app: Any = None
    # Prepended to the message window when no SystemMessage is present.
    system_prompt: str = "You are a helpful assistant."
    # Approximate token budget used when trimming history (see select_window).
    context_window_size: int = 100000

    @classmethod
    async def create(
        cls,
        *,
        model: Any,
        memory: Any,
        system_prompt: str,
        tools: list,
        context_window_size: int = 100000,
    ) -> "Chatbot":
        """Factory method to create and configure a Chatbot instance.

        Args:
            model: The chat model to use.
            memory: The chat memory checkpointer to use.
            system_prompt: The system prompt to initialize the chatbot.
            tools: List of LangChain tools the chatbot can use.
            context_window_size: Maximum tokens for context window.

        Returns:
            A configured Chatbot instance.
        """
        instance = cls(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
            context_window_size=context_window_size,
        )
        # The graph is built after construction so it can close over `instance`.
        instance.app = instance._build_app(memory=memory, tools=tools)
        return instance

    def _build_app(
        self, *, memory: Any, tools: list
    ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
        """Builds the chatbot application workflow using LangGraph.

        Args:
            memory: The chat memory checkpointer to use.
            tools: The list of tools the chatbot can use.

        Returns:
            A compiled state graph representing the chatbot application.
        """
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
            """Selects a window of messages that fit within the context window size.

            Args:
                msgs: The list of messages to select from.

            Returns:
                A list of messages that fit within the context window size.
            """

            def approx_counter(items: list[BaseMessage]) -> int:
                """Approximate token counter for messages.

                Args:
                    items: List of messages to count tokens for.

                Returns:
                    Approximate number of tokens in the messages.
                """
                # Character count as a cheap token proxy; `or ""` guards
                # against messages whose content is None.
                return sum(len(getattr(m, "content", "") or "") for m in items)

            return trim_messages(
                msgs,
                strategy="last",
                token_counter=approx_counter,
                max_tokens=self.context_window_size,
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )

        def agent_node(state: ChatState) -> dict:
            """Agent node for the chatbot workflow.

            Args:
                state: The current chat state.

            Returns:
                The updated chat state after processing.
            """
            # Select the message window to fit in context (trim if needed)
            window = select_window(state.messages)

            # Ensure the system prompt is present at the start
            if not window or not isinstance(window[0], SystemMessage):
                window = [SystemMessage(content=self.system_prompt)] + window

            # Call the LLM with tools
            response = llm_with_tools.invoke(window)

            # Return the new state
            return {"messages": [response]}

        def should_continue(state: ChatState) -> str:
            """Determines whether to continue the workflow or end it.

            This conditional edge is called after the agent node to decide
            whether to continue to the tools node (if the last message contains
            tool calls) or to end the workflow (if no tool calls are present).

            Args:
                state: The current chat state.

            Returns:
                The next node to transition to ("tools" or END).
            """
            # Get the last message
            last_message = state.messages[-1]

            # Check if the last message contains tool calls
            # If so, continue to the tools node; otherwise, end the workflow
            return "tools" if getattr(last_message, "tool_calls", None) else END

        # Compose the workflow: agent -> (tools -> agent)* -> END
        workflow = StateGraph(ChatState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", ToolNode(tools=tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue)
        workflow.add_edge("tools", "agent")

        return workflow.compile(checkpointer=memory)

    async def chat(
        self, *, message: str, chat_id: str = "default"
    ) -> list[BaseMessage]:
        """Processes a chat message and returns the chat history.

        Args:
            message: The user message to process.
            chat_id: The chat thread ID.

        Returns:
            The list of messages in the chat history.
        """
        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}

        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        )

        # Extract and return the messages from the result
        return result["messages"]

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
    ) -> AsyncIterator[dict]:
        """Stream UI-focused events using astream_events v2.

        Args:
            message: The user message to process.
            chat_id: Logical thread identifier; forwarded in the runnable config so
                memory and tools are scoped per thread.

        Yields:
            dict: One of:
                - ``{"type": "status", "label": str}`` for short progress updates.
                - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
                  where ``chat_history`` only includes ``user``/``assistant`` roles.
                - ``{"type": "error", "message": str}`` if an exception occurs.
        """
        # Thread-aware config for LangGraph/LangChain
        config = {"configurable": {"thread_id": chat_id}}

        def _is_root(ev: dict) -> bool:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2",
            ):
                etype = event.get("event")
                ename = event.get("name") or ""
                edata = event.get("data") or {}

                # Stream human-readable progress via the special send_streaming_message tool
                if etype == "on_tool_start" and ename == "send_streaming_message":
                    tool_in = edata.get("input") or {}
                    msg = tool_in.get("message")
                    if isinstance(msg, str) and msg.strip():
                        yield {"type": "status", "label": msg.strip()}
                    continue

                # Emit the final payload when the root run finishes
                if etype == "on_chain_end" and _is_root(event):
                    output_obj = edata.get("output")

                    # Extract message list from the graph's final output
                    final_msgs = ChatStreamingHelper.extract_messages_from_output(
                        output_obj=output_obj
                    )

                    # Normalize for the frontend (only user/assistant with text content)
                    chat_history_payload: list[dict] = []
                    for m in final_msgs:
                        if isinstance(m, BaseMessage):
                            d = ChatStreamingHelper.message_to_dict(msg=m)
                        elif isinstance(m, dict):
                            d = ChatStreamingHelper.dict_message_to_dict(obj=m)
                        else:
                            # Unknown message shapes are skipped entirely.
                            continue
                        if d.get("role") in ("user", "assistant") and d.get("content"):
                            chat_history_payload.append(d)

                    yield {
                        "type": "final",
                        "response": {
                            "thread": chat_id,
                            "chat_history": chat_history_payload,
                        },
                    }
                    return
        except Exception as exc:
            # Emit a single error envelope and end the stream
            logger.error(f"Error in stream_events: {str(exc)}", exc_info=True)
            yield {"type": "error", "message": f"Error processing request: {exc}"}

View file

@ -0,0 +1,239 @@
"""Streaming helper utilities for chat message processing and normalization."""
from typing import Any, Dict, List, Literal, Mapping, Optional
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
Role = Literal["user", "assistant", "system", "tool"]
class ChatStreamingHelper:
"""Pure helper methods for streaming and message normalization.
This class provides static utility methods for converting between different
message formats, extracting content, and normalizing message structures
for streaming chat applications.
"""
@staticmethod
def role_from_message(*, msg: BaseMessage) -> Role:
"""Extract the role from a BaseMessage instance.
Args:
msg: The BaseMessage instance to extract the role from.
Returns:
The role as a string literal: "user", "assistant", "system", or "tool".
Defaults to "assistant" if the message type is not recognized.
Examples:
>>> from langchain_core.messages import HumanMessage
>>> msg = HumanMessage(content="Hello")
>>> ChatStreamingHelper.role_from_message(msg=msg)
'user'
"""
if isinstance(msg, HumanMessage):
return "user"
if isinstance(msg, AIMessage):
return "assistant"
if isinstance(msg, SystemMessage):
return "system"
if isinstance(msg, ToolMessage):
return "tool"
return getattr(msg, "role", "assistant")
@staticmethod
def flatten_content(*, content: Any) -> str:
"""Convert complex content structures to plain text.
This method handles various content formats including strings, lists of
content parts, and dictionaries with text fields. It's designed to
normalize content from different message sources into a consistent
plain text format.
Args:
content: The content to flatten. Can be:
- str: Returned as-is after stripping whitespace
- list: Each item processed and joined with newlines
- dict: Text extracted from "text" or "content" fields
- None: Returns empty string
- Any other type: Converted to string
Returns:
The flattened content as a plain text string with whitespace stripped.
Examples:
>>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
>>> ChatStreamingHelper.flatten_content(content=content)
        'Hello\nworld'
>>> content = {"text": "Simple message"}
>>> ChatStreamingHelper.flatten_content(content=content)
'Simple message'
"""
if content is None:
return ""
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts: List[str] = []
for part in content:
if isinstance(part, dict):
if "text" in part and isinstance(part["text"], str):
parts.append(part["text"])
elif part.get("type") == "text" and isinstance(
part.get("text"), str
):
parts.append(part["text"])
elif "content" in part and isinstance(part["content"], str):
parts.append(part["content"])
else:
# Fallback for unknown dictionary structures
val = part.get("value")
if isinstance(val, str):
parts.append(val)
else:
parts.append(str(part))
return "\n".join(p.strip() for p in parts if p is not None)
if isinstance(content, dict):
if "text" in content and isinstance(content["text"], str):
return content["text"].strip()
if "content" in content and isinstance(content["content"], str):
return content["content"].strip()
return str(content).strip()
@staticmethod
def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
    """Normalize a BaseMessage into a JSON-serializable dictionary.

    Produces the flat message shape used for streaming output: the role
    string, the content flattened to plain text, and (when present and
    truthy on the message) the tool calls and the message name.

    Args:
        msg: The BaseMessage instance to convert.

    Returns:
        A dictionary with "role" and "content" keys, plus optional
        "tool_calls" and "name" keys when the message carries them.
    """
    role = ChatStreamingHelper.role_from_message(msg=msg)
    text = ChatStreamingHelper.flatten_content(
        content=getattr(msg, "content", "")
    )
    result: Dict[str, Any] = {"role": role, "content": text}
    # Optional fields are emitted only when truthy on the message.
    for key in ("tool_calls", "name"):
        value = getattr(msg, key, None)
        if value:
            result[key] = value
    return result
@staticmethod
def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
"""Convert a dictionary-shaped message to a normalized dictionary.
This method handles messages that come from serialized state and are
represented as dictionaries rather than BaseMessage instances. It
normalizes various dictionary formats into a consistent structure.
Args:
obj: The dictionary-shaped message to convert. Expected to contain
fields like "role", "type", "content", "text", etc.
Returns:
A normalized dictionary containing:
- "role": The message role (user, assistant, system, tool)
- "content": The flattened message content as plain text
- "tool_calls": Tool calls if present (optional)
- "name": Message name if present (optional)
Examples:
>>> obj = {"type": "human", "content": "Hello"}
>>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
>>> result["role"]
'user'
>>> result["content"]
'Hello'
"""
role: Optional[str] = obj.get("role")
if not role:
# Handle alternative type field mappings
typ = obj.get("type")
if typ in ("human", "user"):
role = "user"
elif typ in ("ai", "assistant"):
role = "assistant"
elif typ in ("system",):
role = "system"
elif typ in ("tool", "function"):
role = "tool"
content = obj.get("content")
if content is None and "text" in obj:
content = obj["text"]
out: Dict[str, Any] = {
"role": role or "assistant",
"content": ChatStreamingHelper.flatten_content(content=content),
}
if "tool_calls" in obj:
out["tool_calls"] = obj["tool_calls"]
if obj.get("name"):
out["name"] = obj["name"]
return out
@staticmethod
def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
"""Extract messages from LangGraph output objects.
This method handles various output formats from LangGraph execution,
extracting the messages list from different possible structures.
Args:
output_obj: The output object from LangGraph execution. Can be:
- An object with a "messages" attribute
- A dictionary with a "messages" key
- Any other object (returns empty list)
Returns:
A list of extracted messages, or an empty list if no messages
are found or if the output object is None.
Examples:
>>> output = {"messages": [{"role": "user", "content": "Hello"}]}
>>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
>>> len(messages)
1
"""
if output_obj is None:
return []
# Try to parse dicts first
if isinstance(output_obj, dict):
msgs = output_obj.get("messages")
return msgs if isinstance(msgs, list) else []
# Then try to get messages attribute
msgs = getattr(output_obj, "messages", None)
return msgs if isinstance(msgs, list) else []

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,106 @@
"""PostgreSQL checkpointer utilities for LangGraph memory."""
import sys
import asyncio
import logging
from typing import Optional
# Fix for Windows asyncio compatibility with psycopg (backup in case app.py fix didn't apply)
if sys.platform == 'win32':
try:
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
except RuntimeError:
pass # Already set
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool
from psycopg.rows import dict_row
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
# Global checkpointer instance
_checkpointer_instance: Optional[AsyncPostgresSaver] = None
_connection_pool: Optional[AsyncConnectionPool] = None
async def initialize_checkpointer() -> None:
    """Initialize the PostgreSQL checkpointer for LangGraph.

    This should be called during application startup.
    Creates a connection pool and an AsyncPostgresSaver instance, then runs
    its setup() to create the checkpoint tables if needed.

    Raises:
        Exception: Re-raises any error from pool creation or saver setup
            after logging it (and closing a partially-created pool).
    """
    global _checkpointer_instance, _connection_pool
    if _checkpointer_instance is not None:
        logger.info("Checkpointer already initialized")
        return
    from urllib.parse import quote_plus

    try:
        # Get database configuration from environment
        host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost")
        database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat")
        user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev")
        password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET")
        port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432")
        # Build connection string. URL-encode credentials so special
        # characters (e.g. '@', ':', '/') in the user or password do not
        # corrupt the connection URI.
        connection_string = (
            f"postgresql://{quote_plus(str(user))}:{quote_plus(str(password))}"
            f"@{host}:{port}/{database}"
        )
        # Create async connection pool
        _connection_pool = AsyncConnectionPool(
            conninfo=connection_string,
            min_size=2,
            max_size=10,
            kwargs={"autocommit": True, "row_factory": dict_row},
        )
        # Initialize the connection pool
        await _connection_pool.open()
        # Create AsyncPostgresSaver with the pool
        _checkpointer_instance = AsyncPostgresSaver(_connection_pool)
        # Setup the checkpointer (creates tables if needed)
        await _checkpointer_instance.setup()
        logger.info("PostgreSQL checkpointer initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}")
        # Don't leak a half-open pool when setup() fails after pool creation.
        if _connection_pool is not None:
            try:
                await _connection_pool.close()
            except Exception:
                pass
        _connection_pool = None
        _checkpointer_instance = None
        raise
async def close_checkpointer() -> None:
    """Close the checkpointer and its connection pool.

    Intended for application shutdown. Errors while closing the pool are
    logged but not raised; the module-level references are cleared either
    way so a later initialize_checkpointer() starts fresh.
    """
    global _checkpointer_instance, _connection_pool
    if _connection_pool is not None:
        try:
            await _connection_pool.close()
        except Exception as exc:
            logger.error(f"Error closing checkpointer connection pool: {str(exc)}")
        else:
            logger.info("PostgreSQL checkpointer connection pool closed")
    _checkpointer_instance = None
    _connection_pool = None
def get_checkpointer() -> AsyncPostgresSaver:
    """Return the global AsyncPostgresSaver, failing fast when absent.

    Returns:
        The initialized AsyncPostgresSaver instance

    Raises:
        RuntimeError: If initialize_checkpointer() has not been called yet
    """
    checkpointer = _checkpointer_instance
    if checkpointer is None:
        raise RuntimeError(
            "PostgreSQL checkpointer not initialized. "
            "Call initialize_checkpointer() during application startup."
        )
    return checkpointer

View file

@ -0,0 +1,39 @@
"""Mock permissions module for chatbot access control.
This module provides mock permission functions that will be replaced
with actual database-driven permissions in the future.
"""
from datetime import datetime
from modules.features.chatBot.utils.toolRegistry import get_registry
# TODO: Replace these mock implementations with actual database queries
def get_chatbot_tools(*, user_id: str) -> list[str]:
    """Get list of tool IDs that the chatbot can use for a given user.

    Mock implementation: every registered tool is allowed for every user.
    """
    return get_registry().list_tool_ids()
def get_chatbot_model(*, user_id: str) -> str:
    """Gets the chatbot model(s) a user is allowed to use.

    Mock implementation: every user gets the same hard-coded model ID.
    """
    model_id = "claude_4_5"
    return model_id
def get_system_prompt(*, user_id: str) -> str:
    """Get the system prompt for a user's chatbot session.

    Mock implementation returning a generic prompt stamped with today's
    date; in production this will be a user- or role-specific database
    lookup.

    Args:
        user_id: The unique identifier of the user

    Returns:
        The system prompt string with the current date
    """
    today = datetime.now().strftime("%Y-%m-%d")
    return f"You're a smart assistant. Today is {today}"

View file

@ -0,0 +1,305 @@
"""Tool registry for auto-discovering and managing chatbot tools.
This module provides a central registry that automatically discovers all tools
in the chatbotTools directory structure and provides methods to query them.
The registry is built in-memory at startup and does not require a database.
"""
import importlib
import inspect
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
from langchain_core.tools import BaseTool
logger = logging.getLogger(__name__)
@dataclass
class ToolMetadata:
    """Metadata about a discovered chatbot tool.

    Attributes:
        tool_id: Unique identifier (e.g., 'shared.tavily_search')
        name: Function name of the tool
        category: Category of the tool ('shared' or 'customer')
        description: Tool description from docstring
        tool_instance: The actual LangChain tool instance
        module_path: Full Python module path
    """

    tool_id: str
    name: str
    category: str
    description: str
    tool_instance: BaseTool
    module_path: str

    def __str__(self) -> str:
        """Return a pretty-printed string representation for logging."""
        # tool_instance is deliberately omitted: its repr is noisy and the
        # remaining fields identify the tool uniquely.
        return (
            f"ToolMetadata(\n"
            f"    tool_id='{self.tool_id}',\n"
            f"    name='{self.name}',\n"
            f"    category='{self.category}',\n"
            f"    description='{self.description}',\n"
            f"    module_path='{self.module_path}'\n"
            f")"
        )
class ToolRegistry:
    """Central registry for all chatbot tools.

    This class discovers and catalogs all tools decorated with @tool in the
    chatbotTools directory structure. Tools are automatically discovered at
    initialization by scanning the filesystem.

    The registry provides methods to query tools by ID, category, or get all tools.
    """

    def __init__(self) -> None:
        """Initialize an empty tool registry."""
        # tool_id -> ToolMetadata for every discovered tool
        self._tools: Dict[str, ToolMetadata] = {}
        # Guards against repeated filesystem scans (see initialize()).
        self._initialized: bool = False

    def initialize(self) -> None:
        """Discover and register all tools from the chatbotTools directory.

        This method scans both sharedTools and customerTools directories,
        imports all tool*.py modules, and extracts functions decorated with @tool.

        This method is idempotent - calling it multiple times has no effect
        after the first initialization.
        """
        if self._initialized:
            logger.debug("Tool registry already initialized, skipping")
            return
        logger.info("Initializing tool registry...")
        # Get base path to chatbotTools directory (sibling of this module's
        # parent package).
        base_path = Path(__file__).parent.parent / "chatbotTools"
        if not base_path.exists():
            logger.warning(f"chatbotTools directory not found at {base_path}")
            # Still marked initialized so a missing directory is not rescanned.
            self._initialized = True
            return
        # Discover tools in each category
        self._discover_category(
            category_path=base_path / "sharedTools", category="shared"
        )
        self._discover_category(
            category_path=base_path / "customerTools", category="customer"
        )
        self._initialized = True
        logger.info(f"Tool registry initialized with {len(self._tools)} tools")

    def _discover_category(self, *, category_path: Path, category: str) -> None:
        """Discover all tools in a specific category directory.

        Args:
            category_path: Path to the category directory (sharedTools or customerTools)
            category: Category name ('shared' or 'customer')
        """
        if not category_path.exists():
            logger.warning(f"Category directory not found: {category_path}")
            return
        logger.debug(f"Discovering tools in category: {category}")
        # Find all tool*.py files (excluding __init__.py)
        tool_files = [
            f for f in category_path.glob("tool*.py") if f.name != "__init__.py"
        ]
        for tool_file in tool_files:
            self._import_and_register_tools(
                tool_file=tool_file, category=category, category_path=category_path
            )
        logger.debug(f"Discovered {len(tool_files)} tool files in {category}")

    def _import_and_register_tools(
        self, *, tool_file: Path, category: str, category_path: Path
    ) -> None:
        """Import a tool module and register all discovered tools.

        Import errors are logged and swallowed so one broken tool module
        does not prevent the rest of the registry from loading.

        Args:
            tool_file: Path to the tool Python file
            category: Category name ('shared' or 'customer')
            category_path: Path to the category directory
        """
        # Construct module name from the fixed package layout.
        module_name = (
            f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}"
        )
        try:
            # Import the module
            module = importlib.import_module(module_name)
            # Find all BaseTool instances in the module (the @tool decorator
            # turns functions into BaseTool objects).
            tools_found = 0
            for name, obj in inspect.getmembers(module):
                if isinstance(obj, BaseTool):
                    self._register_tool(
                        tool_instance=obj,
                        name=name,
                        category=category,
                        module_path=module_name,
                    )
                    tools_found += 1
            if tools_found == 0:
                logger.warning(f"No tools found in {module_name}")
            else:
                logger.debug(f"Loaded {tools_found} tool(s) from {module_name}")
        except ImportError as e:
            logger.error(
                f"Import error loading tools from {module_name}: {str(e)}. "
                f"This tool will not be available."
            )
        except Exception as e:
            logger.error(
                f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}"
            )

    def _register_tool(
        self, *, tool_instance: BaseTool, name: str, category: str, module_path: str
    ) -> None:
        """Register a single tool in the registry.

        Args:
            tool_instance: The LangChain tool instance
            name: Function name of the tool
            category: Category name ('shared' or 'customer')
            module_path: Full Python module path
        """
        tool_id = f"{category}.{name}"
        # Check for duplicate tool IDs; last one wins.
        if tool_id in self._tools:
            logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting")
        metadata = ToolMetadata(
            tool_id=tool_id,
            name=name,
            category=category,
            description=tool_instance.description or "",
            tool_instance=tool_instance,
            module_path=module_path,
        )
        self._tools[tool_id] = metadata
        logger.debug(f"Registered tool: {tool_id}")

    def get_all_tools(self) -> List[ToolMetadata]:
        """Get all registered tools.

        Returns:
            List of all tool metadata objects
        """
        return list(self._tools.values())

    def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]:
        """Get a specific tool by its ID.

        Args:
            tool_id: The tool identifier (e.g., 'shared.tavily_search')

        Returns:
            Tool metadata if found, None otherwise
        """
        return self._tools.get(tool_id)

    def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]:
        """Get all tools in a specific category.

        Args:
            category: Category name ('shared' or 'customer')

        Returns:
            List of tool metadata for the specified category
        """
        return [t for t in self._tools.values() if t.category == category]

    def list_tool_ids(self) -> List[str]:
        """Get a list of all registered tool IDs.

        Returns:
            List of tool ID strings
        """
        return list(self._tools.keys())

    def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]:
        """Get actual tool instances for a list of tool IDs.

        This is useful for filtering tools based on user permissions.
        Unknown IDs are logged and skipped rather than raising.

        Args:
            tool_ids: List of tool IDs to retrieve

        Returns:
            List of BaseTool instances for the specified IDs
        """
        instances = []
        for tool_id in tool_ids:
            metadata = self.get_tool(tool_id=tool_id)
            if metadata:
                instances.append(metadata.tool_instance)
            else:
                logger.warning(f"Tool ID not found in registry: {tool_id}")
        return instances

    @property
    def is_initialized(self) -> bool:
        """Check if the registry has been initialized.

        Returns:
            True if initialized, False otherwise
        """
        return self._initialized
# Global registry instance
_registry: Optional[ToolRegistry] = None
def get_registry() -> ToolRegistry:
    """Get the global tool registry instance.

    Lazily creates the singleton on first access and ensures it has been
    initialized; subsequent calls return the same instance.

    Returns:
        The global ToolRegistry instance
    """
    global _registry
    registry = _registry
    if registry is None:
        registry = ToolRegistry()
        _registry = registry
    if not registry.is_initialized:
        registry.initialize()
    return registry
def reinitialize_registry() -> ToolRegistry:
    """Force reinitialize the tool registry.

    Discards the current singleton and builds a fresh, initialized one.
    Useful for testing or when tools are added dynamically.

    Returns:
        The reinitialized ToolRegistry instance
    """
    global _registry
    fresh = ToolRegistry()
    fresh.initialize()
    _registry = fresh
    return _registry

File diff suppressed because it is too large Load diff

View file

@ -21,13 +21,14 @@ os.makedirs(staticFolder, exist_ok=True)
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="",
tags=["Administration"],
responses={404: {"description": "Not found"}}
prefix="", tags=["Administration"], responses={404: {"description": "Not found"}}
)
# Mount static files
router.mount("/static", StaticFiles(directory=str(staticFolder), html=True), name="static")
router.mount(
"/static", StaticFiles(directory=str(staticFolder), html=True), name="static"
)
@router.get("/")
@limiter.limit("30/minute")
@ -36,31 +37,40 @@ async def root(request: Request) -> Dict[str, str]:
# Validate required configuration values
allowedOrigins = APP_CONFIG.get("APP_ALLOWED_ORIGINS")
if not allowedOrigins:
raise HTTPException(status_code=500, detail="APP_ALLOWED_ORIGINS configuration is required")
raise HTTPException(
status_code=500, detail="APP_ALLOWED_ORIGINS configuration is required"
)
return {
"status": "online",
"message": "Data Platform API is active",
"allowedOrigins": f"Allowed origins are {allowedOrigins}"
"allowedOrigins": f"Allowed origins are {allowedOrigins}",
}
@router.get("/api/environment")
@limiter.limit("30/minute")
async def get_environment(request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, str]:
async def get_environment(
request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)
) -> Dict[str, str]:
"""Get environment configuration for frontend"""
# Validate required configuration values
apiBaseUrl = APP_CONFIG.get("APP_API_URL")
if not apiBaseUrl:
raise HTTPException(status_code=500, detail="APP_API_URL configuration is required")
raise HTTPException(
status_code=500, detail="APP_API_URL configuration is required"
)
environment = APP_CONFIG.get("APP_ENV")
if not environment:
raise HTTPException(status_code=500, detail="APP_ENV configuration is required")
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
if not instanceLabel:
raise HTTPException(status_code=500, detail="APP_ENV_LABEL configuration is required")
raise HTTPException(
status_code=500, detail="APP_ENV_LABEL configuration is required"
)
return {
"apiBaseUrl": apiBaseUrl,
"environment": environment,
@ -68,13 +78,17 @@ async def get_environment(request: Request, currentUser: Dict[str, Any] = Depend
# Add other environment variables the frontend might need
}
@router.options("/{fullPath:path}")
@limiter.limit("60/minute")
async def options_route(request: Request, fullPath: str) -> Response:
    """Answer CORS preflight OPTIONS requests for any path with 200 OK."""
    return Response(status_code=200)
@router.get("/favicon.ico")
@limiter.limit("30/minute")
async def favicon(request: Request) -> FileResponse:
    """Serve the site favicon from the static folder, 404ing when missing."""
    favicon_path = staticFolder / "favicon.ico"
    # Check existence up front so a missing asset returns a clean 404
    # instead of a 500 from FileResponse.
    if not favicon_path.exists():
        raise HTTPException(status_code=404, detail="Favicon not found")
    return FileResponse(str(favicon_path), media_type="image/x-icon")

View file

@ -0,0 +1,655 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.requests import Request
from fastapi.responses import StreamingResponse
from typing import Any, Dict, List, Optional
from datetime import datetime
import logging
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from modules.features.chatBot.database import get_async_db_session
from modules.features.chatBot.service import (
get_or_create_thread_for_user,
)
from modules.datamodels.datamodelUam import User, UserPrivilege
from modules.datamodels.datamodelChatbot import (
ChatMessageRequest,
MessageItem,
ChatMessageResponse,
ThreadSummary,
ThreadListResponse,
ThreadDetail,
RenameThreadRequest,
DeleteResponse,
ToolListResponse,
ToolInfo,
GrantToolRequest,
GrantToolResponse,
RevokeToolRequest,
RevokeToolResponse,
UpdateToolRequest,
UpdateToolResponse,
)
from modules.security.auth import getCurrentUser, limiter
from modules.features.chatBot import service as chat_service
logger = logging.getLogger(__name__)

# All chatbot endpoints are grouped under the /api/chatbot prefix.
router = APIRouter(
    prefix="/api/chatbot",
    tags=["Chatbot"],
    responses={404: {"description": "Not found"}},
)
# --- Actual endpoints for chatbot ---
@router.post("/message/stream")
@limiter.limit("30/minute")
async def post_chat_message_stream(
    *,
    request: Request,
    message_request: ChatMessageRequest,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> StreamingResponse:
    """
    Post a message to a chat thread with streaming progress updates.
    Creates a new thread if thread_id is not provided.
    Returns Server-Sent Events (SSE) stream with status updates and final response.
    """
    try:
        # Validate and get tools for the request
        tool_ids = await chat_service.validate_and_get_tools_for_request(
            user_id=currentUser.id,
            requested_tool_ids=message_request.tools,
            session=session,
        )
        # Get or create thread using helper function; new threads are named
        # after the first 100 characters of the message.
        thread_id = await get_or_create_thread_for_user(
            thread_id=message_request.thread_id,
            user=currentUser,
            session=session,
            thread_name=message_request.message[:100],
            refresh_date_updated=True,
        )
        logger.info(
            f"User {currentUser.id} posted streaming message to thread {thread_id}"
        )
        # SSE stream: disable caching and keep the connection open.
        return StreamingResponse(
            chat_service.post_message_stream(
                thread_id=thread_id,
                message=message_request.message,
                user=currentUser,
                tool_ids=tool_ids,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
    except PermissionError as e:
        logger.error(f"Permission error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e) or "Permission denied",
        )
    except ValueError as e:
        # NOTE(review): ValueError is mapped to 403 here, while the tool
        # management endpoints map it to 400 — confirm which is intended.
        logger.error(f"Validation error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e) or "Permission denied",
        )
    except Exception as e:
        logger.error(
            f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
        )
@router.post("/message", response_model=ChatMessageResponse)
@limiter.limit("30/minute")
async def post_chat_message(
    *,
    request: Request,
    message_request: ChatMessageRequest,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> ChatMessageResponse:
    """
    Post a message to a chat thread and get assistant response (non-streaming).
    Creates a new thread if thread_id is not provided.
    For streaming updates, use the /message/stream endpoint instead.
    """
    try:
        # Validate and get tools for the request
        tool_ids = await chat_service.validate_and_get_tools_for_request(
            user_id=currentUser.id,
            requested_tool_ids=message_request.tools,
            session=session,
        )
        # Get or create thread using helper function; new threads are named
        # after the first 100 characters of the message.
        thread_id = await get_or_create_thread_for_user(
            thread_id=message_request.thread_id,
            user=currentUser,
            session=session,
            thread_name=message_request.message[:100],
            refresh_date_updated=True,
        )
        logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
        response = await chat_service.post_message(
            thread_id=thread_id,
            message=message_request.message,
            user=currentUser,
            tool_ids=tool_ids,
        )
        return response
    except PermissionError as e:
        logger.error(f"Permission error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e) or "Permission denied",
        )
    except ValueError as e:
        # NOTE(review): ValueError is mapped to 403 here, while the tool
        # management endpoints map it to 400 — confirm which is intended.
        logger.error(f"Validation error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e) or "Permission denied",
        )
    except Exception as e:
        logger.error(
            f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
        )
@router.get("/threads", response_model=ThreadListResponse)
@limiter.limit("30/minute")
async def get_all_threads(
    *,
    request: Request,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> ThreadListResponse:
    """
    Get all chat threads for the current user.
    """
    try:
        # The service layer scopes the query to the calling user.
        user_threads = await chat_service.get_all_threads_for_user(
            user=currentUser, session=session
        )
        logger.info(f"User {currentUser.id} retrieved {len(user_threads)} threads")
        return ThreadListResponse(threads=user_threads)
    except Exception as exc:
        logger.error(
            f"Error retrieving threads: {type(exc).__name__}: {str(exc)}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve threads: {type(exc).__name__}: {str(exc) or 'No error message provided'}",
        )
@router.get("/threads/{thread_id}", response_model=ThreadDetail)
@limiter.limit("30/minute")
async def get_thread_by_id(
    *,
    request: Request,
    thread_id: str,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> ThreadDetail:
    """
    Get a specific chat thread with all its messages from LangGraph checkpointer.
    """
    try:
        thread_view = await chat_service.get_thread_detail_for_user(
            thread_id=thread_id,
            user=currentUser,
            session=session,
        )
        logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
        return thread_view
    except ValueError as exc:
        # Unknown thread id -> 404
        logger.error(f"Thread not found: {str(exc)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc) or "Thread not found",
        )
    except PermissionError as exc:
        # Thread exists but belongs to another user -> 403
        logger.error(f"Permission denied: {str(exc)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(exc) or "Permission denied",
        )
    except Exception as exc:
        logger.error(
            f"Error retrieving thread {thread_id}: {type(exc).__name__}: {str(exc)}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve thread: {type(exc).__name__}: {str(exc) or 'No error message provided'}",
        )
@router.patch("/threads/{thread_id}", response_model=DeleteResponse)
@limiter.limit("30/minute")
async def rename_thread(
    *,
    request: Request,
    thread_id: str,
    rename_request: RenameThreadRequest,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> DeleteResponse:
    """
    Rename a chat thread.
    """
    # NOTE(review): DeleteResponse (message + thread_id) is reused here as a
    # generic acknowledgement payload; a dedicated response model may be clearer.
    try:
        await chat_service.update_thread_name(
            thread_id=thread_id,
            user=currentUser,
            new_thread_name=rename_request.new_name,
            session=session,
        )
        logger.info(
            f"User {currentUser.id} renamed thread {thread_id} to '{rename_request.new_name}'"
        )
        return DeleteResponse(
            message=f"Thread {thread_id} successfully renamed",
            thread_id=thread_id,
        )
    except ValueError as e:
        # Missing thread or wrong owner is surfaced uniformly as 404.
        logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e) or "Thread not found or permission denied",
        )
    except Exception as e:
        logger.error(
            f"Error renaming thread {thread_id}: {type(e).__name__}: {str(e)}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to rename thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
        )
@router.delete("/threads/{thread_id}", response_model=DeleteResponse)
@limiter.limit("10/minute")
async def delete_thread(
    *,
    request: Request,
    thread_id: str,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> DeleteResponse:
    """
    Delete a chat thread and all its associated data from both LangGraph and database.
    """
    try:
        await chat_service.delete_thread_for_user(
            thread_id=thread_id,
            user=currentUser,
            session=session,
        )
        logger.info(f"User {currentUser.id} deleted thread {thread_id}")
        return DeleteResponse(
            message=f"Thread {thread_id} successfully deleted",
            thread_id=thread_id,
        )
    except ValueError as exc:
        # Missing thread or wrong owner -> 404
        logger.error(f"Thread not found or permission denied: {str(exc)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc) or "Thread not found or permission denied",
        )
    except PermissionError as exc:
        # Explicit permission failure -> 403
        logger.error(f"Permission denied: {str(exc)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(exc) or "Permission denied",
        )
    except Exception as exc:
        logger.error(
            f"Error deleting thread {thread_id}: {type(exc).__name__}: {str(exc)}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete thread: {type(exc).__name__}: {str(exc) or 'No error message provided'}",
        )
# Tool Management Endpoints
@router.get("/tools", response_model=ToolListResponse)
@limiter.limit("30/minute")
async def get_all_tools(
    *,
    request: Request,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> ToolListResponse:
    """
    Get all available chatbot tools.
    Only accessible to system administrators.
    """
    try:
        # Admin-only endpoint: reject anyone who is not SYSADMIN.
        if currentUser.privilege != UserPrivilege.SYSADMIN:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Only system administrators can view tools",
            )
        raw_tools = await chat_service.get_all_tools(session=session)
        tool_infos = [ToolInfo(**entry) for entry in raw_tools]
        logger.info(f"User {currentUser.id} retrieved {len(tool_infos)} tools")
        return ToolListResponse(tools=tool_infos)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 403 above) pass through untouched.
        raise
    except Exception as exc:
        logger.error(
            f"Error retrieving tools: {type(exc).__name__}: {str(exc)}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve tools: {type(exc).__name__}: {str(exc) or 'No error message provided'}",
        )
@router.post("/tools/grant", response_model=GrantToolResponse)
@limiter.limit("10/minute")
async def grant_tool_to_user(
    *,
    request: Request,
    grant_request: GrantToolRequest,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> GrantToolResponse:
    """
    Grant a tool to a user.
    Only accessible to system administrators.
    """
    try:
        # Check SYSADMIN permission
        if currentUser.privilege != UserPrivilege.SYSADMIN:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Only system administrators can grant tools",
            )
        # Grant the tool
        await chat_service.grant_tool_to_user(
            user_id=grant_request.user_id,
            tool_id=grant_request.tool_id,
            session=session,
        )
        logger.info(
            f"User {currentUser.id} granted tool {grant_request.tool_id} to user {grant_request.user_id}"
        )
        return GrantToolResponse(
            message=f"Tool successfully granted to user {grant_request.user_id}",
            user_id=grant_request.user_id,
            tool_id=grant_request.tool_id,
        )
    except ValueError as e:
        # Service-level validation failures (e.g. unknown user/tool) -> 400.
        logger.error(f"Validation error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e) or "Invalid request",
        )
    except HTTPException:
        # Deliberate HTTP errors (the 403 above) pass through untouched.
        raise
    except Exception as e:
        logger.error(
            f"Error granting tool: {type(e).__name__}: {str(e)}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to grant tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
        )
@router.delete("/tools/revoke", response_model=RevokeToolResponse)
@limiter.limit("10/minute")
async def revoke_tool_from_user(
    *,
    request: Request,
    revoke_request: RevokeToolRequest,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> RevokeToolResponse:
    """
    Revoke a tool from a user.
    Only accessible to system administrators.
    """
    try:
        # Admin-only endpoint: reject anyone who is not SYSADMIN.
        if currentUser.privilege != UserPrivilege.SYSADMIN:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Only system administrators can revoke tools",
            )
        await chat_service.revoke_tool_from_user(
            user_id=revoke_request.user_id,
            tool_id=revoke_request.tool_id,
            session=session,
        )
        logger.info(
            f"User {currentUser.id} revoked tool {revoke_request.tool_id} from user {revoke_request.user_id}"
        )
        return RevokeToolResponse(
            message=f"Tool successfully revoked from user {revoke_request.user_id}",
            user_id=revoke_request.user_id,
            tool_id=revoke_request.tool_id,
        )
    except ValueError as exc:
        # Service-level validation failures -> 400.
        logger.error(f"Validation error: {str(exc)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc) or "Invalid request",
        )
    except HTTPException:
        # Deliberate HTTP errors (the 403 above) pass through untouched.
        raise
    except Exception as exc:
        logger.error(
            f"Error revoking tool: {type(exc).__name__}: {str(exc)}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to revoke tool: {type(exc).__name__}: {str(exc) or 'No error message provided'}",
        )
@router.patch("/tools/{tool_id}", response_model=UpdateToolResponse)
@limiter.limit("10/minute")
async def update_tool(
    *,
    request: Request,
    tool_id: str,
    update_request: UpdateToolRequest,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> UpdateToolResponse:
    """Update a tool's label and/or description.

    Restricted to system administrators. The response reports which
    fields were actually changed by the service layer.
    """
    # Reject non-admin callers up front; the HTTPException propagates the
    # same way whether raised here or inside the try block below.
    if currentUser.privilege != UserPrivilege.SYSADMIN:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only system administrators can update tools",
        )
    try:
        # The service returns the set of fields it modified.
        changed = await chat_service.update_tool(
            tool_id=tool_id,
            label=update_request.label,
            description=update_request.description,
            session=session,
        )
        logger.info(
            f"User {currentUser.id} updated tool {tool_id}, fields: {changed}"
        )
        return UpdateToolResponse(
            message="Tool successfully updated",
            tool_id=tool_id,
            updated_fields=changed,
        )
    except ValueError as exc:
        # Validation problems from the service become client errors.
        logger.error(f"Validation error: {str(exc)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc) or "Invalid request",
        )
    except HTTPException:
        raise
    except Exception as exc:
        # NOTE(review): detail string leaks internal exception info to clients.
        logger.error(
            f"Error updating tool: {type(exc).__name__}: {str(exc)}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to update tool: {type(exc).__name__}: {str(exc) or 'No error message provided'}",
        )
@router.get("/tools/user/{user_id}", response_model=ToolListResponse)
@limiter.limit("30/minute")
async def get_tools_for_specific_user(
    *,
    request: Request,
    user_id: str,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> ToolListResponse:
    """
    Get all tools granted to a specific user.
    Only accessible to system administrators.

    Raises:
        HTTPException 403: caller is not a system administrator.
        HTTPException 400: the service rejected the request (ValueError).
        HTTPException 500: any other unexpected failure.
    """
    try:
        # Check SYSADMIN permission
        if currentUser.privilege != UserPrivilege.SYSADMIN:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Only system administrators can view user tools",
            )
        # Get tools for the specified user
        tools_data = await chat_service.get_tools_for_user(
            user_id=user_id, session=session
        )
        # Convert the raw service dicts to typed ToolInfo objects
        tools = [ToolInfo(**tool) for tool in tools_data]
        logger.info(
            f"User {currentUser.id} retrieved {len(tools)} tools for user {user_id}"
        )
        return ToolListResponse(tools=tools)
    except ValueError as e:
        # Consistency fix: the sibling tool endpoints (grant/revoke/update)
        # map service-side validation errors to 400; previously this endpoint
        # let ValueError fall through to the generic 500 handler.
        logger.error(f"Validation error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e) or "Invalid request",
        )
    except HTTPException:
        raise
    except Exception as e:
        # NOTE(review): detail string leaks internal exception info to clients.
        logger.error(
            f"Error retrieving tools for user {user_id}: {type(e).__name__}: {str(e)}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve tools for user: {type(e).__name__}: {str(e) or 'No error message provided'}",
        )
@router.get("/tools/me", response_model=ToolListResponse)
@limiter.limit("30/minute")
async def get_my_tools(
    *,
    request: Request,
    currentUser: User = Depends(getCurrentUser),
    session: AsyncSession = Depends(get_async_db_session),
) -> ToolListResponse:
    """Return every tool the calling user currently has access to."""
    try:
        # Look up the caller's own grants via the chat service.
        granted = await chat_service.get_tools_for_user(
            user_id=currentUser.id, session=session
        )
        # Wrap the raw service dicts in typed ToolInfo models.
        tool_models = [ToolInfo(**entry) for entry in granted]
        logger.info(
            f"User {currentUser.id} retrieved {len(tool_models)} tools for themselves"
        )
        return ToolListResponse(tools=tool_models)
    except Exception as exc:
        # NOTE(review): detail string leaks internal exception info to clients.
        logger.error(
            f"Error retrieving tools for user {currentUser.id}: {type(exc).__name__}: {str(exc)}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve your tools: {type(exc).__name__}: {str(exc) or 'No error message provided'}",
        )

View file

@ -2,52 +2,55 @@
Shared utilities for model attributes and labels.
"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Dict, Any, List, Type, Optional, Union
import inspect
import importlib
import os
from datetime import datetime
class ModelMixin:
"""Mixin class that provides serialization methods for Pydantic models."""
def to_dict(self) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
Handles both Pydantic v1 and v2.
All timestamp fields remain as float values.
Returns:
Dict[str, Any]: Dictionary representation of the model
"""
# Get the raw dictionary
if hasattr(self, 'model_dump'):
if hasattr(self, "model_dump"):
data: Dict[str, Any] = self.model_dump() # Pydantic v2
else:
data: Dict[str, Any] = self.dict() # Pydantic v1
# All fields (including timestamps) remain in their original format
# No conversions needed - timestamps are already float
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ModelMixin':
def from_dict(cls, data: Dict[str, Any]) -> "ModelMixin":
"""
Create a Pydantic model instance from a dictionary.
Args:
data: Dictionary containing the model data
Returns:
ModelMixin: New instance of the model class
"""
return cls(**data)
# Define the AttributeDefinition class here instead of importing it
class AttributeDefinition(BaseModel, ModelMixin):
"""Definition of a model attribute with its metadata."""
name: str
type: str
label: str
@ -64,41 +67,47 @@ class AttributeDefinition(BaseModel, ModelMixin):
order: int = 0
placeholder: Optional[str] = None
# Global registry for model labels
MODEL_LABELS: Dict[str, Dict[str, Dict[str, str]]] = {}
def to_dict(model: BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
Handles both Pydantic v1 and v2.
Args:
model: The Pydantic model instance to convert
Returns:
Dict[str, Any]: Dictionary representation of the model
"""
if hasattr(model, 'model_dump'):
if hasattr(model, "model_dump"):
return model.model_dump() # Pydantic v2
return model.dict() # Pydantic v1
def from_dict(model_class: Type[BaseModel], data: Dict[str, Any]) -> BaseModel:
    """Instantiate *model_class* from a plain dictionary.

    Args:
        model_class: The Pydantic model class to build.
        data: Field values keyed by attribute name.

    Returns:
        A new ``model_class`` instance populated from ``data``.
    """
    instance = model_class(**data)
    return instance
def register_model_labels(model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]]):
def register_model_labels(
model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]]
):
"""
Register labels for a model's attributes and the model itself.
Args:
model_name: Name of the model class
model_label: Dictionary mapping language codes to model labels
@ -106,38 +115,37 @@ def register_model_labels(model_name: str, model_label: Dict[str, str], labels:
labels: Dictionary mapping attribute names to their translations
e.g. {"name": {"en": "Name", "fr": "Nom"}}
"""
MODEL_LABELS[model_name] = {
"model": model_label,
"attributes": labels
}
MODEL_LABELS[model_name] = {"model": model_label, "attributes": labels}
def get_model_labels(model_name: str, language: str = "en") -> Dict[str, str]:
    """Resolve attribute labels for *model_name* in *language*.

    Falls back to the English translation, then to the raw attribute
    name, when no label is registered for the requested language.

    Args:
        model_name: Name of the model class
        language: Language code (default: "en")

    Returns:
        Dictionary mapping attribute names to their resolved labels
    """
    registered = MODEL_LABELS.get(model_name, {})
    per_attribute = registered.get("attributes", {})
    resolved: Dict[str, str] = {}
    for attribute, translations in per_attribute.items():
        resolved[attribute] = translations.get(
            language, translations.get("en", attribute)
        )
    return resolved
def get_model_label(model_name: str, language: str = "en") -> str:
"""
Get the label for a model in the specified language.
Args:
model_name: Name of the model class
language: Language code (default: "en")
Returns:
Model label in the specified language, or model name if no label exists
"""
@ -145,173 +153,228 @@ def get_model_label(model_name: str, language: str = "en") -> str:
model_label = model_data.get("model", {})
return model_label.get(language, model_label.get("en", model_name))
def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguage: str = "en") -> Dict[str, Any]:
def getModelAttributeDefinitions(
modelClass: Type[BaseModel] = None, userLanguage: str = "en"
) -> Dict[str, Any]:
"""
Get attribute definitions for a model class.
Args:
modelClass: The model class to get attributes for
userLanguage: Language code for translations (default: "en")
Returns:
Dictionary containing model label and attribute definitions
"""
if not modelClass:
return {}
attributes = []
model_name = modelClass.__name__
labels = get_model_labels(model_name, userLanguage)
model_label = get_model_label(model_name, userLanguage)
# Handle both Pydantic v1 and v2
if hasattr(modelClass, 'model_fields'): # Pydantic v2
if hasattr(modelClass, "model_fields"): # Pydantic v2
fields = modelClass.model_fields
for name, field in fields.items():
# Extract frontend metadata from field info
field_info = field.field_info if hasattr(field, 'field_info') else None
field_info = field.field_info if hasattr(field, "field_info") else None
# Check both direct attributes and extra field for frontend metadata
frontend_type = None
frontend_readonly = False
frontend_required = field.is_required()
frontend_options = None
if field_info:
# Try direct attributes first
frontend_type = getattr(field_info, 'frontend_type', None)
frontend_readonly = getattr(field_info, 'frontend_readonly', False)
frontend_required = getattr(field_info, 'frontend_required', frontend_required)
frontend_options = getattr(field_info, 'frontend_options', None)
frontend_type = getattr(field_info, "frontend_type", None)
frontend_readonly = getattr(field_info, "frontend_readonly", False)
frontend_required = getattr(
field_info, "frontend_required", frontend_required
)
frontend_options = getattr(field_info, "frontend_options", None)
# If not found, check extra field
if hasattr(field_info, 'extra') and field_info.extra:
if hasattr(field_info, "extra") and field_info.extra:
if frontend_type is None:
frontend_type = field_info.extra.get('frontend_type')
frontend_type = field_info.extra.get("frontend_type")
if not frontend_readonly:
frontend_readonly = field_info.extra.get('frontend_readonly', False)
if frontend_required == field.is_required(): # Only override if we didn't get it from direct attribute
frontend_required = field_info.extra.get('frontend_required', frontend_required)
frontend_readonly = field_info.extra.get(
"frontend_readonly", False
)
if (
frontend_required == field.is_required()
): # Only override if we didn't get it from direct attribute
frontend_required = field_info.extra.get(
"frontend_required", frontend_required
)
if frontend_options is None:
frontend_options = field_info.extra.get('frontend_options')
frontend_options = field_info.extra.get("frontend_options")
# Use frontend type if available, otherwise fall back to Python type
field_type = frontend_type if frontend_type else (field.annotation.__name__ if hasattr(field.annotation, "__name__") else str(field.annotation))
attributes.append({
"name": name,
"type": field_type,
"required": frontend_required,
"description": field.description if hasattr(field, "description") else "",
"label": labels.get(name, name),
"placeholder": f"Please enter {labels.get(name, name)}",
"editable": not frontend_readonly,
"visible": True,
"order": len(attributes),
"readonly": frontend_readonly,
"options": frontend_options
})
field_type = (
frontend_type
if frontend_type
else (
field.annotation.__name__
if hasattr(field.annotation, "__name__")
else str(field.annotation)
)
)
attributes.append(
{
"name": name,
"type": field_type,
"required": frontend_required,
"description": field.description
if hasattr(field, "description")
else "",
"label": labels.get(name, name),
"placeholder": f"Please enter {labels.get(name, name)}",
"editable": not frontend_readonly,
"visible": True,
"order": len(attributes),
"readonly": frontend_readonly,
"options": frontend_options,
}
)
else: # Pydantic v1
fields = modelClass.__fields__
for name, field in fields.items():
# Extract frontend metadata from field info
field_info = field.field_info if hasattr(field, 'field_info') else None
field_info = field.field_info if hasattr(field, "field_info") else None
# Check both direct attributes and extra field for frontend metadata
frontend_type = None
frontend_readonly = False
frontend_required = field.required
frontend_options = None
if field_info:
# Try direct attributes first
frontend_type = getattr(field_info, 'frontend_type', None)
frontend_readonly = getattr(field_info, 'frontend_readonly', False)
frontend_required = getattr(field_info, 'frontend_required', frontend_required)
frontend_options = getattr(field_info, 'frontend_options', None)
frontend_type = getattr(field_info, "frontend_type", None)
frontend_readonly = getattr(field_info, "frontend_readonly", False)
frontend_required = getattr(
field_info, "frontend_required", frontend_required
)
frontend_options = getattr(field_info, "frontend_options", None)
# If not found, check extra field
if hasattr(field_info, 'extra') and field_info.extra:
if hasattr(field_info, "extra") and field_info.extra:
if frontend_type is None:
frontend_type = field_info.extra.get('frontend_type')
frontend_type = field_info.extra.get("frontend_type")
if not frontend_readonly:
frontend_readonly = field_info.extra.get('frontend_readonly', False)
if frontend_required == field.required: # Only override if we didn't get it from direct attribute
frontend_required = field_info.extra.get('frontend_required', frontend_required)
frontend_readonly = field_info.extra.get(
"frontend_readonly", False
)
if (
frontend_required == field.required
): # Only override if we didn't get it from direct attribute
frontend_required = field_info.extra.get(
"frontend_required", frontend_required
)
if frontend_options is None:
frontend_options = field_info.extra.get('frontend_options')
frontend_options = field_info.extra.get("frontend_options")
# Use frontend type if available, otherwise fall back to Python type
field_type = frontend_type if frontend_type else (field.type_.__name__ if hasattr(field.type_, "__name__") else str(field.type_))
attributes.append({
"name": name,
"type": field_type,
"required": frontend_required,
"description": field.field_info.description if hasattr(field.field_info, "description") else "",
"label": labels.get(name, name),
"placeholder": f"Please enter {labels.get(name, name)}",
"editable": not frontend_readonly,
"visible": True,
"order": len(attributes),
"readonly": frontend_readonly,
"options": frontend_options
})
return {
"model": model_label,
"attributes": attributes
}
field_type = (
frontend_type
if frontend_type
else (
field.type_.__name__
if hasattr(field.type_, "__name__")
else str(field.type_)
)
)
attributes.append(
{
"name": name,
"type": field_type,
"required": frontend_required,
"description": field.field_info.description
if hasattr(field.field_info, "description")
else "",
"label": labels.get(name, name),
"placeholder": f"Please enter {labels.get(name, name)}",
"editable": not frontend_readonly,
"visible": True,
"order": len(attributes),
"readonly": frontend_readonly,
"options": frontend_options,
}
)
return {"model": model_label, "attributes": attributes}
def getModelClasses() -> Dict[str, Type[BaseModel]]:
"""
Dynamically get all model classes from all model modules.
Returns:
Dict[str, Type[BaseModel]]: Dictionary of model class names to their classes
"""
modelClasses = {}
# Get the interfaces directory path
interfaces_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'interfaces')
interfaces_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "interfaces"
)
# Find all model files in interfaces directory
for fileName in os.listdir(interfaces_dir):
if fileName.endswith('Model.py'):
if fileName.endswith("Model.py"):
# Convert fileName to module name (e.g., gatewayModel.py -> gatewayModel)
module_name = fileName[:-3]
# Import the module dynamically
module = importlib.import_module(f'modules.interfaces.{module_name}')
module = importlib.import_module(f"modules.interfaces.{module_name}")
# Get all classes from the module
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj != BaseModel:
if (
inspect.isclass(obj)
and issubclass(obj, BaseModel)
and obj != BaseModel
):
modelClasses[name] = obj
# Also get models from datamodels directory
datamodels_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'datamodels')
datamodels_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "datamodels"
)
# Find all model files in datamodels directory
for fileName in os.listdir(datamodels_dir):
if fileName.startswith('datamodel') and fileName.endswith('.py'):
if fileName.startswith("datamodel") and fileName.endswith(".py"):
# Convert fileName to module name (e.g., datamodelUtils.py -> datamodelUtils)
module_name = fileName[:-3]
# Import the module dynamically
module = importlib.import_module(f'modules.datamodels.{module_name}')
module = importlib.import_module(f"modules.datamodels.{module_name}")
# Get all classes from the module
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj != BaseModel:
if (
inspect.isclass(obj)
and issubclass(obj, BaseModel)
and obj != BaseModel
):
modelClasses[name] = obj
return modelClasses
class AttributeResponse(BaseModel):
"""Response model for entity attributes"""
attributes: List[AttributeDefinition]
class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra={
"example": {
"attributes": [
{
@ -322,8 +385,9 @@ class AttributeResponse(BaseModel):
"placeholder": "Please enter username",
"editable": True,
"visible": True,
"order": 0
"order": 0,
}
]
}
}
}
)

View file

@ -1,10 +1,10 @@
## Web Framework & API
fastapi==0.104.1
fastapi==0.115.0 # Upgraded for Pydantic v2 compatibility
websockets==12.0
uvicorn==0.23.2
python-multipart==0.0.6
httpx==0.25.0
pydantic==1.10.13 # Older version without a Rust dependency
httpx>=0.25.2
pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility
email-validator==2.0.0 # Required by Pydantic for email validation
slowapi==0.1.8 # For rate limiting
@ -109,3 +109,14 @@ xyzservices>=2021.09.1
# PostgreSQL connector dependencies
psycopg2-binary==2.9.9
## LangChain & LangGraph
langchain==0.3.27
langgraph==0.6.8
langchain-core==0.3.77
langchain-anthropic==0.3.1 # For Claude models
psycopg[binary]==3.2.1 # For PostgreSQL async support (LangGraph checkpointer)
psycopg-pool==3.2.1 # Connection pooling for PostgreSQL
langgraph-checkpoint-postgres==2.0.24
greenlet==3.2.4

View file

@ -0,0 +1,198 @@
"""Pytest tests for the tool registry.
This module tests that the tool registry correctly discovers and catalogs
all tools in the chatbotTools directory.
"""
import logging
import pytest
from modules.features.chatBot.utils.toolRegistry import (
ToolMetadata,
ToolRegistry,
get_registry,
reinitialize_registry,
)
from langchain_core.tools import BaseTool
logger = logging.getLogger(__name__)
class TestToolRegistry:
    """Test suite for ToolRegistry class.

    Each test receives a freshly reinitialized registry via the
    ``registry`` fixture, so discovery runs anew per test.
    """
    @pytest.fixture
    def registry(self) -> ToolRegistry:
        """Provide a fresh registry instance for each test."""
        # Reinitialize so state from previous tests cannot leak in.
        return reinitialize_registry()
    def test_registry_initialization(self, registry: ToolRegistry) -> None:
        """Test that registry initializes correctly."""
        assert registry.is_initialized
        # NOTE(review): reaches into the private `_tools` attribute; assumes
        # the registry stores its catalog as a dict -- confirm this is a
        # stable implementation detail worth asserting on.
        assert isinstance(registry._tools, dict)
    def test_get_all_tools(self, registry: ToolRegistry) -> None:
        """Test getting all registered tools."""
        all_tools = registry.get_all_tools()
        assert isinstance(all_tools, list)
        assert len(all_tools) > 0
        assert all(isinstance(tool, ToolMetadata) for tool in all_tools)
        # Log all discovered tools
        logger.info(f"Found {len(all_tools)} tools in registry:")
        for tool in all_tools:
            logger.info(f"\n{tool}")
    def test_tool_metadata_structure(self, registry: ToolRegistry) -> None:
        """Test that tool metadata has correct structure."""
        all_tools = registry.get_all_tools()
        # Every entry must expose the full metadata contract with the
        # expected types; category is a closed two-value set.
        for tool in all_tools:
            assert isinstance(tool.tool_id, str)
            assert isinstance(tool.name, str)
            assert isinstance(tool.category, str)
            assert tool.category in ["shared", "customer"]
            assert isinstance(tool.description, str)
            assert isinstance(tool.tool_instance, BaseTool)
            assert isinstance(tool.module_path, str)
    def test_list_tool_ids(self, registry: ToolRegistry) -> None:
        """Test listing all tool IDs."""
        tool_ids = registry.list_tool_ids()
        assert isinstance(tool_ids, list)
        assert len(tool_ids) > 0
        assert all(isinstance(tool_id, str) for tool_id in tool_ids)
        # Check that tool IDs follow expected format
        for tool_id in tool_ids:
            assert "." in tool_id
            # IDs are "<category>.<name>"; split only on the first dot so
            # tool names may themselves contain dots.
            category, name = tool_id.split(".", 1)
            assert category in ["shared", "customer"]
    def test_get_specific_tool(self, registry: ToolRegistry) -> None:
        """Test retrieving a specific tool by ID."""
        # Get all tool IDs first
        tool_ids = registry.list_tool_ids()
        if tool_ids:
            # Test with first available tool
            test_tool_id = tool_ids[0]
            tool_metadata = registry.get_tool(tool_id=test_tool_id)
            assert tool_metadata is not None
            assert isinstance(tool_metadata, ToolMetadata)
            assert tool_metadata.tool_id == test_tool_id
    def test_get_nonexistent_tool(self, registry: ToolRegistry) -> None:
        """Test retrieving a tool that doesn't exist."""
        # Lookup of an unknown ID returns None rather than raising.
        tool_metadata = registry.get_tool(tool_id="nonexistent.tool")
        assert tool_metadata is None
    def test_get_tools_by_category_shared(self, registry: ToolRegistry) -> None:
        """Test getting all shared tools."""
        shared_tools = registry.get_tools_by_category(category="shared")
        assert isinstance(shared_tools, list)
        assert all(tool.category == "shared" for tool in shared_tools)
    def test_get_tools_by_category_customer(self, registry: ToolRegistry) -> None:
        """Test getting all customer tools."""
        customer_tools = registry.get_tools_by_category(category="customer")
        assert isinstance(customer_tools, list)
        assert all(tool.category == "customer" for tool in customer_tools)
    def test_get_tool_instances(self, registry: ToolRegistry) -> None:
        """Test getting tool instances by IDs."""
        tool_ids = registry.list_tool_ids()
        if len(tool_ids) >= 2:
            # Test with first two tools
            test_ids = tool_ids[:2]
            instances = registry.get_tool_instances(tool_ids=test_ids)
            assert isinstance(instances, list)
            assert len(instances) == 2
            assert all(isinstance(inst, BaseTool) for inst in instances)
    def test_get_tool_instances_with_invalid_id(self, registry: ToolRegistry) -> None:
        """Test getting tool instances with some invalid IDs."""
        tool_ids = registry.list_tool_ids()
        if tool_ids:
            # Mix valid and invalid IDs
            test_ids = [tool_ids[0], "invalid.tool"]
            instances = registry.get_tool_instances(tool_ids=test_ids)
            # Should only return the valid one
            assert len(instances) == 1
            assert isinstance(instances[0], BaseTool)
    def test_global_registry_singleton(self) -> None:
        """Test that get_registry returns same instance."""
        # Two lookups of the module-level registry must be identical objects.
        registry1 = get_registry()
        registry2 = get_registry()
        assert registry1 is registry2
    def test_reinitialize_registry(self) -> None:
        """Test that reinitialize creates new instance."""
        registry1 = get_registry()
        registry2 = reinitialize_registry()
        # Should be different instances after reinitialize
        assert registry1 is not registry2
        assert registry2.is_initialized
class TestToolDiscovery:
    """Test suite for tool discovery functionality.

    These tests go through ``get_registry()`` (the possibly cached global
    instance), so they exercise discovery results as the application sees
    them. Tool-specific tests are conditional: a tool that failed to
    import is logged as a warning instead of failing the suite.
    """
    def test_discovers_at_least_one_tool(self) -> None:
        """Test that at least one tool is discovered."""
        registry = get_registry()
        tool_ids = registry.list_tool_ids()
        # At least one tool should be successfully loaded
        assert len(tool_ids) >= 1, "Expected at least one tool to be discovered"
    def test_query_althaus_database_if_available(self) -> None:
        """Test query_althaus_database tool if it was successfully loaded."""
        registry = get_registry()
        tool = registry.get_tool(tool_id="customer.query_althaus_database")
        if tool is not None:
            assert tool.name == "query_althaus_database"
            assert tool.category == "customer"
            assert "database" in tool.description.lower()
        else:
            # Tool may not have loaded due to import errors - log warning.
            # Fix: dropped the redundant function-local `import logging`;
            # the module already imports logging at the top.
            logging.warning(
                "customer.query_althaus_database tool not found - "
                "may have failed to import"
            )
    def test_tavily_search_if_available(self) -> None:
        """Test tavily_search tool if it was successfully loaded."""
        registry = get_registry()
        tool = registry.get_tool(tool_id="shared.tavily_search")
        if tool is not None:
            assert tool.name == "tavily_search"
            assert tool.category == "shared"
            assert "search" in tool.description.lower()
        else:
            # Tool may not have loaded due to import errors - log warning
            logging.warning(
                "shared.tavily_search tool not found - may have failed to import"
            )
    def test_tool_ids_have_correct_format(self) -> None:
        """Test that all discovered tool IDs follow the expected format."""
        registry = get_registry()
        tool_ids = registry.list_tool_ids()
        for tool_id in tool_ids:
            # All tool IDs should have format: category.toolname
            assert "." in tool_id, f"Tool ID {tool_id} missing category separator"
            category, name = tool_id.split(".", 1)
            assert category in [
                "shared",
                "customer",
            ], f"Tool {tool_id} has invalid category: {category}"
            assert len(name) > 0, f"Tool {tool_id} has empty name"