commit
7b1a26590c
26 changed files with 5672 additions and 838 deletions
270
app.py
270
app.py
|
|
@ -1,10 +1,22 @@
|
|||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
from urllib.parse import quote_plus
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||
from modules.features.chatBot.database import init_models as init_chatbot_models
|
||||
|
||||
os.environ["NUMEXPR_MAX_THREADS"] = "12"
|
||||
|
||||
# Fix for Windows asyncio compatibility with psycopg
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends, Body, status, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
|
||||
from fastapi.security import HTTPBearer
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
|
@ -20,32 +32,36 @@ class DailyRotatingFileHandler(RotatingFileHandler):
|
|||
A rotating file handler that automatically switches to a new file when the date changes.
|
||||
The log file name includes the current date and switches at midnight.
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir, filename_prefix, max_bytes=10485760, backup_count=5, **kwargs):
|
||||
|
||||
def __init__(
|
||||
self, log_dir, filename_prefix, max_bytes=10485760, backup_count=5, **kwargs
|
||||
):
|
||||
self.log_dir = log_dir
|
||||
self.filename_prefix = filename_prefix
|
||||
self.current_date = None
|
||||
self.current_file = None
|
||||
|
||||
|
||||
# Initialize with today's file
|
||||
self._update_file_if_needed()
|
||||
|
||||
|
||||
# Call parent constructor with current file
|
||||
super().__init__(self.current_file, maxBytes=max_bytes, backupCount=backup_count, **kwargs)
|
||||
|
||||
super().__init__(
|
||||
self.current_file, maxBytes=max_bytes, backupCount=backup_count, **kwargs
|
||||
)
|
||||
|
||||
def _update_file_if_needed(self):
|
||||
"""Update the log file if the date has changed"""
|
||||
today = datetime.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
if self.current_date != today:
|
||||
self.current_date = today
|
||||
new_file = os.path.join(self.log_dir, f"{self.filename_prefix}_{today}.log")
|
||||
|
||||
|
||||
if self.current_file != new_file:
|
||||
self.current_file = new_file
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def emit(self, record):
|
||||
"""Emit a log record, switching files if date has changed"""
|
||||
# Check if we need to switch to a new file
|
||||
|
|
@ -54,16 +70,17 @@ class DailyRotatingFileHandler(RotatingFileHandler):
|
|||
if self.stream:
|
||||
self.stream.close()
|
||||
self.stream = None
|
||||
|
||||
|
||||
# Update the baseFilename for the parent class
|
||||
self.baseFilename = self.current_file
|
||||
# Reopen the stream
|
||||
if not self.delay:
|
||||
self.stream = self._open()
|
||||
|
||||
|
||||
# Call parent emit method
|
||||
super().emit(record)
|
||||
|
||||
|
||||
def initLogging():
|
||||
"""Initialize logging with configuration from APP_CONFIG"""
|
||||
# Get log level from config (default to INFO if not found)
|
||||
|
|
@ -76,33 +93,39 @@ def initLogging():
|
|||
# If relative path, make it relative to the gateway directory
|
||||
gatewayDir = os.path.dirname(os.path.abspath(__file__))
|
||||
logDir = os.path.join(gatewayDir, logDir)
|
||||
|
||||
|
||||
# Ensure log directory exists
|
||||
os.makedirs(logDir, exist_ok=True)
|
||||
|
||||
# Create formatters - using single line format
|
||||
consoleFormatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
|
||||
# File formatter with more detailed error information but still single line
|
||||
fileFormatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s - %(pathname)s:%(lineno)d - %(funcName)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
# Add filter to exclude Chrome DevTools requests
|
||||
class ChromeDevToolsFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
return not (isinstance(record.msg, str) and
|
||||
('.well-known/appspecific/com.chrome.devtools.json' in record.msg or
|
||||
'Request: /index.html' in record.msg))
|
||||
return not (
|
||||
isinstance(record.msg, str)
|
||||
and (
|
||||
".well-known/appspecific/com.chrome.devtools.json" in record.msg
|
||||
or "Request: /index.html" in record.msg
|
||||
)
|
||||
)
|
||||
|
||||
# Add filter to exclude all httpcore loggers (including sub-loggers)
|
||||
class HttpcoreStarFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
return not (record.name == 'httpcore' or record.name.startswith('httpcore.'))
|
||||
return not (
|
||||
record.name == "httpcore" or record.name.startswith("httpcore.")
|
||||
)
|
||||
|
||||
# Add filter to exclude HTTP debug messages
|
||||
class HTTPDebugFilter(logging.Filter):
|
||||
|
|
@ -110,14 +133,14 @@ def initLogging():
|
|||
if isinstance(record.msg, str):
|
||||
# Filter out HTTP debug messages
|
||||
http_debug_patterns = [
|
||||
'receive_response_body.started',
|
||||
'receive_response_body.complete',
|
||||
'response_closed.started',
|
||||
'_send_single_request',
|
||||
'httpcore.http11',
|
||||
'httpx._client',
|
||||
'HTTP Request',
|
||||
'multipart.multipart'
|
||||
"receive_response_body.started",
|
||||
"receive_response_body.complete",
|
||||
"response_closed.started",
|
||||
"_send_single_request",
|
||||
"httpcore.http11",
|
||||
"httpx._client",
|
||||
"HTTP Request",
|
||||
"multipart.multipart",
|
||||
]
|
||||
return not any(pattern in record.msg for pattern in http_debug_patterns)
|
||||
return True
|
||||
|
|
@ -129,13 +152,21 @@ def initLogging():
|
|||
# Remove only emojis, preserve other Unicode characters like quotes
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
# Remove emoji characters specifically
|
||||
record.msg = ''.join(char for char in record.msg if unicodedata.category(char) != 'So' or not (0x1F600 <= ord(char) <= 0x1F64F or 0x1F300 <= ord(char) <= 0x1F5FF or 0x1F680 <= ord(char) <= 0x1F6FF or 0x1F1E0 <= ord(char) <= 0x1F1FF or 0x2600 <= ord(char) <= 0x26FF or 0x2700 <= ord(char) <= 0x27BF))
|
||||
# Additionally strip characters not representable in Windows cp1252 (e.g., arrows)
|
||||
try:
|
||||
record.msg.encode('cp1252', errors='strict')
|
||||
except UnicodeEncodeError:
|
||||
record.msg = record.msg.encode('cp1252', errors='ignore').decode('cp1252', errors='ignore')
|
||||
record.msg = "".join(
|
||||
char
|
||||
for char in record.msg
|
||||
if unicodedata.category(char) != "So"
|
||||
or not (
|
||||
0x1F600 <= ord(char) <= 0x1F64F
|
||||
or 0x1F300 <= ord(char) <= 0x1F5FF
|
||||
or 0x1F680 <= ord(char) <= 0x1F6FF
|
||||
or 0x1F1E0 <= ord(char) <= 0x1F1FF
|
||||
or 0x2600 <= ord(char) <= 0x26FF
|
||||
or 0x2700 <= ord(char) <= 0x27BF
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
# Configure handlers based on config
|
||||
|
|
@ -154,14 +185,16 @@ def initLogging():
|
|||
# Add file handler if enabled
|
||||
if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True):
|
||||
# Create daily application log file with automatic date switching
|
||||
rotationSize = int(APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760)) # Default: 10MB
|
||||
rotationSize = int(
|
||||
APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760)
|
||||
) # Default: 10MB
|
||||
backupCount = int(APP_CONFIG.get("APP_LOGGING_BACKUP_COUNT", 5))
|
||||
|
||||
|
||||
fileHandler = DailyRotatingFileHandler(
|
||||
log_dir=logDir,
|
||||
filename_prefix="log_app",
|
||||
max_bytes=rotationSize,
|
||||
backup_count=backupCount
|
||||
backup_count=backupCount,
|
||||
)
|
||||
fileHandler.setFormatter(fileFormatter)
|
||||
fileHandler.addFilter(ChromeDevToolsFilter())
|
||||
|
|
@ -176,11 +209,18 @@ def initLogging():
|
|||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
handlers=handlers,
|
||||
force=True # Force reconfiguration of the root logger
|
||||
force=True, # Force reconfiguration of the root logger
|
||||
)
|
||||
|
||||
# Silence noisy third-party libraries - use the same level as the root logger
|
||||
noisyLoggers = ["httpx", "httpcore", "urllib3", "asyncio", "fastapi.security.oauth2", "msal"]
|
||||
noisyLoggers = [
|
||||
"httpx",
|
||||
"httpcore",
|
||||
"urllib3",
|
||||
"asyncio",
|
||||
"fastapi.security.oauth2",
|
||||
"msal",
|
||||
]
|
||||
for loggerName in noisyLoggers:
|
||||
logging.getLogger(loggerName).setLevel(logging.WARNING)
|
||||
|
||||
|
|
@ -188,38 +228,142 @@ def initLogging():
|
|||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Logging initialized with level {logLevelName}")
|
||||
logger.info(f"Log directory: {logDir}")
|
||||
|
||||
if APP_CONFIG.get('APP_LOGGING_FILE_ENABLED', True):
|
||||
|
||||
if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True):
|
||||
today = datetime.now().strftime("%Y%m%d")
|
||||
appLogFile = os.path.join(logDir, f"log_app_{today}.log")
|
||||
logger.info(f"Application log file: {appLogFile} (auto-switches daily)")
|
||||
else:
|
||||
logger.info("Application log file: disabled")
|
||||
|
||||
logger.info(f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}")
|
||||
|
||||
logger.info(
|
||||
f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}"
|
||||
)
|
||||
|
||||
|
||||
def make_sqlalchemy_db_url() -> str:
|
||||
host = APP_CONFIG.get("SQLALCHEMY_DB_HOST", "localhost")
|
||||
port = APP_CONFIG.get("SQLALCHEMY_DB_PORT", "5432")
|
||||
db = APP_CONFIG.get("SQLALCHEMY_DB_DATABASE", "project_gateway")
|
||||
user = APP_CONFIG.get("SQLALCHEMY_DB_USER", "postgres")
|
||||
pwd = quote_plus(APP_CONFIG.get("SQLALCHEMY_DB_PASSWORD_SECRET", ""))
|
||||
return f"postgresql+psycopg://{user}:{pwd}@{host}:{port}/{db}"
|
||||
|
||||
|
||||
# Initialize logging
|
||||
initLogging()
|
||||
logger = logging.getLogger(__name__)
|
||||
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
|
||||
|
||||
|
||||
# Define lifespan context manager for application startup/shutdown events
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Application is starting up")
|
||||
|
||||
# --- Init SQLAlchemy ---
|
||||
|
||||
engine = create_async_engine(
|
||||
make_sqlalchemy_db_url(), pool_pre_ping=True, echo=False
|
||||
)
|
||||
SessionLocal = async_sessionmaker(
|
||||
engine, expire_on_commit=False, class_=AsyncSession
|
||||
)
|
||||
app.state.checkpoint_engine = engine
|
||||
app.state.checkpoint_sessionmaker = SessionLocal
|
||||
|
||||
# NOTE: Might need Alembic migrations in the future
|
||||
await init_chatbot_models(engine)
|
||||
|
||||
# --- Sync tools from registry to database ---
|
||||
from modules.features.chatBot.database import sync_tools_from_registry
|
||||
|
||||
async with SessionLocal() as session:
|
||||
await sync_tools_from_registry(session)
|
||||
await session.commit()
|
||||
logger.info("Tools synced from registry to database")
|
||||
|
||||
# --- Initialize LangGraph checkpointer ---
|
||||
|
||||
from modules.features.chatBot.utils.checkpointer import (
|
||||
initialize_checkpointer,
|
||||
close_checkpointer,
|
||||
)
|
||||
|
||||
try:
|
||||
await initialize_checkpointer()
|
||||
logger.info("LangGraph checkpointer initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LangGraph checkpointer: {str(e)}")
|
||||
# Continue startup even if checkpointer fails to initialize
|
||||
|
||||
# --- Init Event Manager ---
|
||||
eventManager.start()
|
||||
|
||||
yield
|
||||
|
||||
# --- Cleanup Event Manager ---
|
||||
eventManager.stop()
|
||||
|
||||
# --- Cleanup LangGraph checkpointer ---
|
||||
await close_checkpointer()
|
||||
|
||||
# --- Cleanup SQLAlchemy ---
|
||||
await engine.dispose()
|
||||
|
||||
logger.info("Application has been shut down")
|
||||
|
||||
|
||||
# START APP
|
||||
app = FastAPI(
|
||||
title="PowerOn | Data Platform API",
|
||||
title="PowerOn | Data Platform API",
|
||||
description=f"Backend API for the Multi-Agent Platform by ValueOn AG ({instanceLabel})",
|
||||
lifespan=lifespan
|
||||
lifespan=lifespan,
|
||||
swagger_ui_init_oauth={
|
||||
"usePkceWithAuthorizationCodeGrant": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Configure OpenAPI security scheme for Swagger UI
|
||||
# This adds the "Authorize" button to the /docs page
|
||||
security_scheme = HTTPBearer()
|
||||
app.openapi_schema = None # Reset schema to regenerate with security
|
||||
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version="1.0.0",
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add security scheme definition
|
||||
openapi_schema["components"]["securitySchemes"] = {
|
||||
"BearerAuth": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT",
|
||||
"description": "Enter your JWT token (obtained from login endpoint or browser cookies)",
|
||||
}
|
||||
}
|
||||
|
||||
# Apply security globally to all endpoints
|
||||
# Individual endpoints can override this if needed
|
||||
openapi_schema["security"] = [{"BearerAuth": []}]
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
|
||||
# Parse CORS origins from environment variable
|
||||
def get_allowed_origins():
|
||||
origins_str = APP_CONFIG.get("APP_ALLOWED_ORIGINS", "http://localhost:8080")
|
||||
|
|
@ -228,73 +372,99 @@ def get_allowed_origins():
|
|||
logger.info(f"CORS allowed origins: {origins}")
|
||||
return origins
|
||||
|
||||
|
||||
# CORS configuration using environment variables
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins= get_allowed_origins(),
|
||||
allow_origins=get_allowed_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"],
|
||||
max_age=86400 # Increased caching for preflight requests
|
||||
max_age=86400, # Increased caching for preflight requests
|
||||
)
|
||||
|
||||
# CSRF protection middleware
|
||||
from modules.security.csrf import CSRFMiddleware
|
||||
from modules.security.tokenRefreshMiddleware import TokenRefreshMiddleware, ProactiveTokenRefreshMiddleware
|
||||
from modules.security.tokenRefreshMiddleware import (
|
||||
TokenRefreshMiddleware,
|
||||
ProactiveTokenRefreshMiddleware,
|
||||
)
|
||||
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# Token refresh middleware (silent refresh for expired OAuth tokens)
|
||||
app.add_middleware(TokenRefreshMiddleware, enabled=True)
|
||||
|
||||
# Proactive token refresh middleware (refresh tokens before they expire)
|
||||
app.add_middleware(ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5)
|
||||
app.add_middleware(
|
||||
ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5
|
||||
)
|
||||
|
||||
# Run triggered features
|
||||
import modules.features.init
|
||||
|
||||
# Include all routers
|
||||
from modules.routes.routeAdmin import router as generalRouter
|
||||
|
||||
app.include_router(generalRouter)
|
||||
|
||||
from modules.routes.routeAttributes import router as attributesRouter
|
||||
|
||||
app.include_router(attributesRouter)
|
||||
|
||||
from modules.routes.routeDataMandates import router as mandateRouter
|
||||
|
||||
app.include_router(mandateRouter)
|
||||
|
||||
from modules.routes.routeDataUsers import router as userRouter
|
||||
|
||||
app.include_router(userRouter)
|
||||
|
||||
from modules.routes.routeDataFiles import router as fileRouter
|
||||
|
||||
app.include_router(fileRouter)
|
||||
|
||||
from modules.routes.routeDataNeutralization import router as neutralizationRouter
|
||||
|
||||
app.include_router(neutralizationRouter)
|
||||
|
||||
from modules.routes.routeDataPrompts import router as promptRouter
|
||||
|
||||
app.include_router(promptRouter)
|
||||
|
||||
from modules.routes.routeDataConnections import router as connectionsRouter
|
||||
|
||||
app.include_router(connectionsRouter)
|
||||
|
||||
from modules.routes.routeWorkflows import router as workflowRouter
|
||||
|
||||
app.include_router(workflowRouter)
|
||||
|
||||
from modules.routes.routeChatPlayground import router as chatPlaygroundRouter
|
||||
|
||||
app.include_router(chatPlaygroundRouter)
|
||||
|
||||
from modules.routes.routeSecurityLocal import router as localRouter
|
||||
|
||||
app.include_router(localRouter)
|
||||
|
||||
from modules.routes.routeSecurityMsft import router as msftRouter
|
||||
|
||||
app.include_router(msftRouter)
|
||||
|
||||
from modules.routes.routeSecurityGoogle import router as googleRouter
|
||||
|
||||
app.include_router(googleRouter)
|
||||
|
||||
from modules.routes.routeVoiceGoogle import router as voiceGoogleRouter
|
||||
|
||||
app.include_router(voiceGoogleRouter)
|
||||
|
||||
from modules.routes.routeSecurityAdmin import router as adminSecurityRouter
|
||||
app.include_router(adminSecurityRouter)
|
||||
|
||||
app.include_router(adminSecurityRouter)
|
||||
|
||||
from modules.routes.routeChatbot import router as chatbotRouter
|
||||
|
||||
app.include_router(chatbotRouter)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,21 +1,33 @@
|
|||
"""Chat models: ChatWorkflow, ChatMessage, ChatLog, ChatStat, ChatDocument."""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
import uuid
|
||||
|
||||
|
||||
class ChatStat(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)")
|
||||
messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)")
|
||||
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
workflowId: Optional[str] = Field(
|
||||
None, description="Foreign key to workflow (for workflow stats)"
|
||||
)
|
||||
messageId: Optional[str] = Field(
|
||||
None, description="Foreign key to message (for message stats)"
|
||||
)
|
||||
processingTime: Optional[float] = Field(
|
||||
None, description="Processing time in seconds"
|
||||
)
|
||||
tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
|
||||
bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
|
||||
bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
|
||||
successRate: Optional[float] = Field(None, description="Success rate of operations")
|
||||
errorCount: Optional[int] = Field(None, description="Number of errors encountered")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatStat",
|
||||
{"en": "Chat Statistics", "fr": "Statistiques de chat"},
|
||||
|
|
@ -32,15 +44,27 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatLog(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
workflowId: str = Field(description="Foreign key to workflow")
|
||||
message: str = Field(description="Log message")
|
||||
type: str = Field(description="Log type (info, warning, error, etc.)")
|
||||
timestamp: float = Field(default_factory=get_utc_timestamp, description="When the log entry was created (UTC timestamp in seconds)")
|
||||
timestamp: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="When the log entry was created (UTC timestamp in seconds)",
|
||||
)
|
||||
status: Optional[str] = Field(None, description="Status of the log entry")
|
||||
progress: Optional[float] = Field(None, description="Progress indicator (0.0 to 1.0)")
|
||||
performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics")
|
||||
progress: Optional[float] = Field(
|
||||
None, description="Progress indicator (0.0 to 1.0)"
|
||||
)
|
||||
performance: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Performance metrics"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatLog",
|
||||
{"en": "Chat Log", "fr": "Journal de chat"},
|
||||
|
|
@ -56,8 +80,11 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatDocument(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
messageId: str = Field(description="Foreign key to message")
|
||||
fileId: str = Field(description="Foreign key to file")
|
||||
fileName: str = Field(description="Name of the file")
|
||||
|
|
@ -66,7 +93,11 @@ class ChatDocument(BaseModel, ModelMixin):
|
|||
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
||||
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
||||
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
||||
actionId: Optional[str] = Field(None, description="ID of the action that created this document")
|
||||
actionId: Optional[str] = Field(
|
||||
None, description="ID of the action that created this document"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatDocument",
|
||||
{"en": "Chat Document", "fr": "Document de chat"},
|
||||
|
|
@ -84,17 +115,26 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ContentMetadata(BaseModel, ModelMixin):
|
||||
size: int = Field(description="Content size in bytes")
|
||||
pages: Optional[int] = Field(None, description="Number of pages for multi-page content")
|
||||
pages: Optional[int] = Field(
|
||||
None, description="Number of pages for multi-page content"
|
||||
)
|
||||
error: Optional[str] = Field(None, description="Processing error if any")
|
||||
width: Optional[int] = Field(None, description="Width in pixels for images/videos")
|
||||
height: Optional[int] = Field(None, description="Height in pixels for images/videos")
|
||||
height: Optional[int] = Field(
|
||||
None, description="Height in pixels for images/videos"
|
||||
)
|
||||
colorMode: Optional[str] = Field(None, description="Color mode")
|
||||
fps: Optional[float] = Field(None, description="Frames per second for videos")
|
||||
durationSec: Optional[float] = Field(None, description="Duration in seconds for media")
|
||||
durationSec: Optional[float] = Field(
|
||||
None, description="Duration in seconds for media"
|
||||
)
|
||||
mimeType: str = Field(description="MIME type of the content")
|
||||
base64Encoded: bool = Field(description="Whether the data is base64 encoded")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ContentMetadata",
|
||||
{"en": "Content Metadata", "fr": "Métadonnées du contenu"},
|
||||
|
|
@ -112,10 +152,13 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ContentItem(BaseModel, ModelMixin):
|
||||
label: str = Field(description="Content label")
|
||||
data: str = Field(description="Extracted text content")
|
||||
metadata: ContentMetadata = Field(description="Content metadata")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ContentItem",
|
||||
{"en": "Content Item", "fr": "Élément de contenu"},
|
||||
|
|
@ -126,9 +169,14 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatContentExtracted(BaseModel, ModelMixin):
|
||||
id: str = Field(description="Reference to source ChatDocument")
|
||||
contents: List[ContentItem] = Field(default_factory=list, description="List of content items")
|
||||
contents: List[ContentItem] = Field(
|
||||
default_factory=list, description="List of content items"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatContentExtracted",
|
||||
{"en": "Extracted Content", "fr": "Contenu extrait"},
|
||||
|
|
@ -138,28 +186,58 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
workflowId: str = Field(description="Foreign key to workflow")
|
||||
parentMessageId: Optional[str] = Field(None, description="Parent message ID for threading")
|
||||
documents: List[ChatDocument] = Field(default_factory=list, description="Associated documents")
|
||||
documentsLabel: Optional[str] = Field(None, description="Label for the set of documents")
|
||||
parentMessageId: Optional[str] = Field(
|
||||
None, description="Parent message ID for threading"
|
||||
)
|
||||
documents: List[ChatDocument] = Field(
|
||||
default_factory=list, description="Associated documents"
|
||||
)
|
||||
documentsLabel: Optional[str] = Field(
|
||||
None, description="Label for the set of documents"
|
||||
)
|
||||
message: Optional[str] = Field(None, description="Message content")
|
||||
summary: Optional[str] = Field(None, description="Short summary of this message for planning/history")
|
||||
summary: Optional[str] = Field(
|
||||
None, description="Short summary of this message for planning/history"
|
||||
)
|
||||
role: str = Field(description="Role of the message sender")
|
||||
status: str = Field(description="Status of the message (first, step, last)")
|
||||
sequenceNr: int = Field(description="Sequence number of the message (set automatically)")
|
||||
publishedAt: float = Field(default_factory=get_utc_timestamp, description="When the message was published (UTC timestamp in seconds)")
|
||||
sequenceNr: int = Field(
|
||||
description="Sequence number of the message (set automatically)"
|
||||
)
|
||||
publishedAt: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="When the message was published (UTC timestamp in seconds)",
|
||||
)
|
||||
stats: Optional[ChatStat] = Field(None, description="Statistics for this message")
|
||||
success: Optional[bool] = Field(None, description="Whether the message processing was successful")
|
||||
actionId: Optional[str] = Field(None, description="ID of the action that produced this message")
|
||||
actionMethod: Optional[str] = Field(None, description="Method of the action that produced this message")
|
||||
actionName: Optional[str] = Field(None, description="Name of the action that produced this message")
|
||||
success: Optional[bool] = Field(
|
||||
None, description="Whether the message processing was successful"
|
||||
)
|
||||
actionId: Optional[str] = Field(
|
||||
None, description="ID of the action that produced this message"
|
||||
)
|
||||
actionMethod: Optional[str] = Field(
|
||||
None, description="Method of the action that produced this message"
|
||||
)
|
||||
actionName: Optional[str] = Field(
|
||||
None, description="Name of the action that produced this message"
|
||||
)
|
||||
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
||||
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
||||
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
||||
taskProgress: Optional[str] = Field(None, description="Task progress status: pending, running, success, fail, retry")
|
||||
actionProgress: Optional[str] = Field(None, description="Action progress status: pending, running, success, fail")
|
||||
taskProgress: Optional[str] = Field(
|
||||
None, description="Task progress status: pending, running, success, fail, retry"
|
||||
)
|
||||
actionProgress: Optional[str] = Field(
|
||||
None, description="Action progress status: pending, running, success, fail"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessage",
|
||||
{"en": "Chat Message", "fr": "Message de chat"},
|
||||
|
|
@ -188,32 +266,139 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatWorkflow(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
mandateId: str = Field(description="ID of the mandate this workflow belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
status: str = Field(description="Current status of the workflow", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
|
||||
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
|
||||
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
|
||||
{"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
|
||||
{"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
|
||||
])
|
||||
name: Optional[str] = Field(None, description="Name of the workflow", frontend_type="text", frontend_readonly=False, frontend_required=True)
|
||||
currentRound: int = Field(description="Current round number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
currentTask: int = Field(default=0, description="Current task number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
currentAction: int = Field(default=0, description="Current action number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
totalTasks: int = Field(default=0, description="Total number of tasks in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
totalActions: int = Field(default=0, description="Total number of actions in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
lastActivity: float = Field(default_factory=get_utc_timestamp, description="Timestamp of last activity (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
startedAt: float = Field(default_factory=get_utc_timestamp, description="When the workflow started (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
stats: Optional[ChatStat] = Field(None, description="Workflow statistics", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
tasks: list = Field(default_factory=list, description="List of tasks in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
workflowMode: str = Field(default="Actionplan", description="Workflow mode selector", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
|
||||
{"value": "Actionplan", "label": {"en": "Action Plan", "fr": "Plan d'actions"}},
|
||||
{"value": "React", "label": {"en": "React", "fr": "Réactif"}},
|
||||
])
|
||||
maxSteps: int = Field(default=5, description="Maximum number of iterations in react mode", frontend_type="integer", frontend_readonly=False, frontend_required=False)
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Primary key",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
mandateId: str = Field(
|
||||
description="ID of the mandate this workflow belongs to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
status: str = Field(
|
||||
description="Current status of the workflow",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False,
|
||||
frontend_options=[
|
||||
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
|
||||
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
|
||||
{"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
|
||||
{"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
|
||||
],
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of the workflow",
|
||||
frontend_type="text",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True,
|
||||
)
|
||||
currentRound: int = Field(
|
||||
description="Current round number",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
currentTask: int = Field(
|
||||
default=0,
|
||||
description="Current task number",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
currentAction: int = Field(
|
||||
default=0,
|
||||
description="Current action number",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
totalTasks: int = Field(
|
||||
default=0,
|
||||
description="Total number of tasks in the workflow",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
totalActions: int = Field(
|
||||
default=0,
|
||||
description="Total number of actions in the workflow",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
lastActivity: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="Timestamp of last activity (UTC timestamp in seconds)",
|
||||
frontend_type="timestamp",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
startedAt: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="When the workflow started (UTC timestamp in seconds)",
|
||||
frontend_type="timestamp",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
logs: List[ChatLog] = Field(
|
||||
default_factory=list,
|
||||
description="Workflow logs",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
messages: List[ChatMessage] = Field(
|
||||
default_factory=list,
|
||||
description="Messages in the workflow",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
stats: Optional[ChatStat] = Field(
|
||||
None,
|
||||
description="Workflow statistics",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
tasks: list = Field(
|
||||
default_factory=list,
|
||||
description="List of tasks in the workflow",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
workflowMode: str = Field(
|
||||
default="Actionplan",
|
||||
description="Workflow mode selector",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False,
|
||||
frontend_options=[
|
||||
{
|
||||
"value": "Actionplan",
|
||||
"label": {"en": "Action Plan", "fr": "Plan d'actions"},
|
||||
},
|
||||
{"value": "React", "label": {"en": "React", "fr": "Réactif"}},
|
||||
],
|
||||
)
|
||||
maxSteps: int = Field(
|
||||
default=5,
|
||||
description="Maximum number of iterations in react mode",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False,
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatWorkflow",
|
||||
{"en": "Chat Workflow", "fr": "Flux de travail de chat"},
|
||||
|
|
@ -238,12 +423,16 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class UserInputRequest(BaseModel, ModelMixin):
|
||||
prompt: str = Field(description="Prompt for the user")
|
||||
listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
|
||||
userLanguage: str = Field(default="en", description="User's preferred language")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"UserInputRequest", {"en": "User Input Request", "fr": "Demande de saisie utilisateur"},
|
||||
"UserInputRequest",
|
||||
{"en": "User Input Request", "fr": "Demande de saisie utilisateur"},
|
||||
{
|
||||
"prompt": {"en": "Prompt", "fr": "Invite"},
|
||||
"listFileId": {"en": "File IDs", "fr": "IDs des fichiers"},
|
||||
|
|
@ -251,11 +440,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionDocument(BaseModel, ModelMixin):
|
||||
"""Clear document structure for action results"""
|
||||
|
||||
documentName: str = Field(description="Name of the document")
|
||||
documentData: Any = Field(description="Content/data of the document")
|
||||
mimeType: str = Field(description="MIME type of the document")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionDocument",
|
||||
{"en": "Action Document", "fr": "Document d'action"},
|
||||
|
|
@ -266,6 +459,7 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionResult(BaseModel, ModelMixin):
|
||||
"""Clean action result with documents as primary output
|
||||
|
||||
|
|
@ -276,16 +470,25 @@ class ActionResult(BaseModel, ModelMixin):
|
|||
|
||||
success: bool = Field(description="Whether execution succeeded")
|
||||
error: Optional[str] = Field(None, description="Error message if failed")
|
||||
documents: List[ActionDocument] = Field(default_factory=list, description="Document outputs")
|
||||
resultLabel: Optional[str] = Field(None, description="Label for document routing (set by action handler, not by action methods)")
|
||||
documents: List[ActionDocument] = Field(
|
||||
default_factory=list, description="Document outputs"
|
||||
)
|
||||
resultLabel: Optional[str] = Field(
|
||||
None,
|
||||
description="Label for document routing (set by action handler, not by action methods)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def isSuccess(cls, documents: List[ActionDocument] = None) -> "ActionResult":
|
||||
return cls(success=True, documents=documents or [])
|
||||
|
||||
@classmethod
|
||||
def isFailure(cls, error: str, documents: List[ActionDocument] = None) -> "ActionResult":
|
||||
def isFailure(
|
||||
cls, error: str, documents: List[ActionDocument] = None
|
||||
) -> "ActionResult":
|
||||
return cls(success=False, documents=documents or [], error=error)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionResult",
|
||||
{"en": "Action Result", "fr": "Résultat de l'action"},
|
||||
|
|
@ -297,9 +500,14 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionSelection(BaseModel, ModelMixin):
|
||||
method: str = Field(description="Method to execute (e.g., web, document, ai)")
|
||||
name: str = Field(description="Action name within the method (e.g., search, extract)")
|
||||
name: str = Field(
|
||||
description="Action name within the method (e.g., search, extract)"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionSelection",
|
||||
{"en": "Action Selection", "fr": "Sélection d'action"},
|
||||
|
|
@ -309,8 +517,13 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionParameters(BaseModel, ModelMixin):
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to execute the selected action")
|
||||
parameters: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Parameters to execute the selected action"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionParameters",
|
||||
{"en": "Action Parameters", "fr": "Paramètres d'action"},
|
||||
|
|
@ -319,10 +532,13 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ObservationPreview(BaseModel, ModelMixin):
|
||||
name: str = Field(description="Document name or URL label")
|
||||
mime: str = Field(description="MIME type or kind")
|
||||
snippet: str = Field(description="Short snippet or summary")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ObservationPreview",
|
||||
{"en": "Observation Preview", "fr": "Aperçu d'observation"},
|
||||
|
|
@ -333,12 +549,19 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class Observation(BaseModel, ModelMixin):
|
||||
success: bool = Field(description="Action execution success flag")
|
||||
resultLabel: str = Field(description="Deterministic label for produced documents")
|
||||
documentsCount: int = Field(description="Number of produced documents")
|
||||
previews: List[ObservationPreview] = Field(default_factory=list, description="Compact previews of outputs")
|
||||
notes: List[str] = Field(default_factory=list, description="Short notes or key facts")
|
||||
previews: List[ObservationPreview] = Field(
|
||||
default_factory=list, description="Compact previews of outputs"
|
||||
)
|
||||
notes: List[str] = Field(
|
||||
default_factory=list, description="Short notes or key facts"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"Observation",
|
||||
{"en": "Observation", "fr": "Observation"},
|
||||
|
|
@ -351,12 +574,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
class TaskStatus(str):
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskStatus",
|
||||
{"en": "Task Status", "fr": "Statut de la tâche"},
|
||||
|
|
@ -369,9 +595,14 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class DocumentExchange(BaseModel, ModelMixin):
|
||||
documentsLabel: str = Field(description="Label for the set of documents")
|
||||
documents: List[str] = Field(default_factory=list, description="List of document references")
|
||||
documents: List[str] = Field(
|
||||
default_factory=list, description="List of document references"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"DocumentExchange",
|
||||
{"en": "Document Exchange", "fr": "Échange de documents"},
|
||||
|
|
@ -381,20 +612,33 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionItem(BaseModel, ModelMixin):
|
||||
id: str = Field(..., description="Action ID")
|
||||
execMethod: str = Field(..., description="Method to execute")
|
||||
execAction: str = Field(..., description="Action to perform")
|
||||
execParameters: Dict[str, Any] = Field(default_factory=dict, description="Action parameters")
|
||||
execResultLabel: Optional[str] = Field(None, description="Label for the set of result documents")
|
||||
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(None, description="Expected document formats (optional)")
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
execParameters: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Action parameters"
|
||||
)
|
||||
execResultLabel: Optional[str] = Field(
|
||||
None, description="Label for the set of result documents"
|
||||
)
|
||||
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(
|
||||
None, description="Expected document formats (optional)"
|
||||
)
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="User-friendly message in user's language"
|
||||
)
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Action status")
|
||||
error: Optional[str] = Field(None, description="Error message if action failed")
|
||||
retryCount: int = Field(default=0, description="Number of retries attempted")
|
||||
retryMax: int = Field(default=3, description="Maximum number of retries")
|
||||
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
|
||||
timestamp: float = Field(..., description="When the action was executed (UTC timestamp in seconds)")
|
||||
processingTime: Optional[float] = Field(
|
||||
None, description="Processing time in seconds"
|
||||
)
|
||||
timestamp: float = Field(
|
||||
..., description="When the action was executed (UTC timestamp in seconds)"
|
||||
)
|
||||
result: Optional[str] = Field(None, description="Result of the action")
|
||||
|
||||
def setSuccess(self, result: str = None) -> None:
|
||||
|
|
@ -408,6 +652,8 @@ class ActionItem(BaseModel, ModelMixin):
|
|||
"""Set the action as failed with error message"""
|
||||
self.status = TaskStatus.FAILED
|
||||
self.error = error_message
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionItem",
|
||||
{"en": "Task Action", "fr": "Action de tâche"},
|
||||
|
|
@ -417,7 +663,10 @@ register_model_labels(
|
|||
"execAction": {"en": "Action", "fr": "Action"},
|
||||
"execParameters": {"en": "Parameters", "fr": "Paramètres"},
|
||||
"execResultLabel": {"en": "Result Label", "fr": "Label du résultat"},
|
||||
"expectedDocumentFormats": {"en": "Expected Document Formats", "fr": "Formats de documents attendus"},
|
||||
"expectedDocumentFormats": {
|
||||
"en": "Expected Document Formats",
|
||||
"fr": "Formats de documents attendus",
|
||||
},
|
||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||
"status": {"en": "Status", "fr": "Statut"},
|
||||
"error": {"en": "Error", "fr": "Erreur"},
|
||||
|
|
@ -429,12 +678,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskResult(BaseModel, ModelMixin):
|
||||
taskId: str = Field(..., description="Task ID")
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
|
||||
success: bool = Field(..., description="Whether the task was successful")
|
||||
feedback: Optional[str] = Field(None, description="Task feedback message")
|
||||
error: Optional[str] = Field(None, description="Error message if task failed")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskResult",
|
||||
{"en": "Task Result", "fr": "Résultat de tâche"},
|
||||
|
|
@ -447,22 +699,39 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskItem(BaseModel, ModelMixin):
|
||||
id: str = Field(..., description="Task ID")
|
||||
workflowId: str = Field(..., description="Workflow ID")
|
||||
userInput: str = Field(..., description="User input that triggered the task")
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
|
||||
error: Optional[str] = Field(None, description="Error message if task failed")
|
||||
startedAt: Optional[float] = Field(None, description="When the task started (UTC timestamp in seconds)")
|
||||
finishedAt: Optional[float] = Field(None, description="When the task finished (UTC timestamp in seconds)")
|
||||
actionList: List[ActionItem] = Field(default_factory=list, description="List of actions to execute")
|
||||
startedAt: Optional[float] = Field(
|
||||
None, description="When the task started (UTC timestamp in seconds)"
|
||||
)
|
||||
finishedAt: Optional[float] = Field(
|
||||
None, description="When the task finished (UTC timestamp in seconds)"
|
||||
)
|
||||
actionList: List[ActionItem] = Field(
|
||||
default_factory=list, description="List of actions to execute"
|
||||
)
|
||||
retryCount: int = Field(default=0, description="Number of retries attempted")
|
||||
retryMax: int = Field(default=3, description="Maximum number of retries")
|
||||
rollbackOnFailure: bool = Field(default=True, description="Whether to rollback on failure")
|
||||
dependencies: List[str] = Field(default_factory=list, description="List of task IDs this task depends on")
|
||||
rollbackOnFailure: bool = Field(
|
||||
default=True, description="Whether to rollback on failure"
|
||||
)
|
||||
dependencies: List[str] = Field(
|
||||
default_factory=list, description="List of task IDs this task depends on"
|
||||
)
|
||||
feedback: Optional[str] = Field(None, description="Task feedback message")
|
||||
processingTime: Optional[float] = Field(None, description="Total processing time in seconds")
|
||||
resultLabels: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Map of result labels to their values")
|
||||
processingTime: Optional[float] = Field(
|
||||
None, description="Total processing time in seconds"
|
||||
)
|
||||
resultLabels: Optional[Dict[str, Any]] = Field(
|
||||
default_factory=dict, description="Map of result labels to their values"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskItem",
|
||||
{"en": "Task", "fr": "Tâche"},
|
||||
|
|
@ -481,13 +750,18 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskStep(BaseModel, ModelMixin):
|
||||
id: str
|
||||
objective: str
|
||||
dependencies: Optional[list[str]] = Field(default_factory=list)
|
||||
success_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
estimated_complexity: Optional[str] = None
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="User-friendly message in user's language"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskStep",
|
||||
{"en": "Task Step", "fr": "Étape de tâche"},
|
||||
|
|
@ -496,23 +770,45 @@ register_model_labels(
|
|||
"objective": {"en": "Objective", "fr": "Objectif"},
|
||||
"dependencies": {"en": "Dependencies", "fr": "Dépendances"},
|
||||
"success_criteria": {"en": "Success Criteria", "fr": "Critères de succès"},
|
||||
"estimated_complexity": {"en": "Estimated Complexity", "fr": "Complexité estimée"},
|
||||
"estimated_complexity": {
|
||||
"en": "Estimated Complexity",
|
||||
"fr": "Complexité estimée",
|
||||
},
|
||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskHandover(BaseModel, ModelMixin):
|
||||
taskId: str = Field(description="Target task ID")
|
||||
sourceTask: Optional[str] = Field(None, description="Source task ID")
|
||||
inputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Available input documents")
|
||||
outputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Produced output documents")
|
||||
inputDocuments: List[DocumentExchange] = Field(
|
||||
default_factory=list, description="Available input documents"
|
||||
)
|
||||
outputDocuments: List[DocumentExchange] = Field(
|
||||
default_factory=list, description="Produced output documents"
|
||||
)
|
||||
context: Dict[str, Any] = Field(default_factory=dict, description="Task context")
|
||||
previousResults: List[str] = Field(default_factory=list, description="Previous result summaries")
|
||||
improvements: List[str] = Field(default_factory=list, description="Improvement suggestions")
|
||||
workflowSummary: Optional[str] = Field(None, description="Summarized workflow context")
|
||||
messageHistory: List[str] = Field(default_factory=list, description="Key message summaries")
|
||||
timestamp: float = Field(..., description="When the handover was created (UTC timestamp in seconds)")
|
||||
handoverType: str = Field(default="task", description="Type of handover: task, phase, or workflow")
|
||||
previousResults: List[str] = Field(
|
||||
default_factory=list, description="Previous result summaries"
|
||||
)
|
||||
improvements: List[str] = Field(
|
||||
default_factory=list, description="Improvement suggestions"
|
||||
)
|
||||
workflowSummary: Optional[str] = Field(
|
||||
None, description="Summarized workflow context"
|
||||
)
|
||||
messageHistory: List[str] = Field(
|
||||
default_factory=list, description="Key message summaries"
|
||||
)
|
||||
timestamp: float = Field(
|
||||
..., description="When the handover was created (UTC timestamp in seconds)"
|
||||
)
|
||||
handoverType: str = Field(
|
||||
default="task", description="Type of handover: task, phase, or workflow"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskHandover",
|
||||
{"en": "Task Handover", "fr": "Transfert de tâche"},
|
||||
|
|
@ -531,9 +827,10 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskContext(BaseModel, ModelMixin):
|
||||
task_step: TaskStep
|
||||
workflow: Optional['ChatWorkflow'] = None
|
||||
workflow: Optional["ChatWorkflow"] = None
|
||||
workflow_id: Optional[str] = None
|
||||
available_documents: Optional[str] = "No documents available"
|
||||
available_connections: Optional[list[str]] = Field(default_factory=list)
|
||||
|
|
@ -562,6 +859,7 @@ class TaskContext(BaseModel, ModelMixin):
|
|||
self.improvements = []
|
||||
self.improvements.append(improvement)
|
||||
|
||||
|
||||
class ReviewContext(BaseModel, ModelMixin):
|
||||
task_step: TaskStep
|
||||
task_actions: Optional[list] = Field(default_factory=list)
|
||||
|
|
@ -570,6 +868,7 @@ class ReviewContext(BaseModel, ModelMixin):
|
|||
workflow_id: Optional[str] = None
|
||||
previous_results: Optional[list[str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ReviewResult(BaseModel, ModelMixin):
|
||||
status: str
|
||||
reason: Optional[str] = None
|
||||
|
|
@ -579,7 +878,11 @@ class ReviewResult(BaseModel, ModelMixin):
|
|||
met_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
unmet_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
confidence: Optional[float] = 0.5
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="User-friendly message in user's language"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ReviewResult",
|
||||
{"en": "Review Result", "fr": "Résultat de l'évaluation"},
|
||||
|
|
@ -596,10 +899,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskPlan(BaseModel, ModelMixin):
|
||||
overview: str
|
||||
tasks: list[TaskStep]
|
||||
userMessage: Optional[str] = Field(None, description="Overall user-friendly message for the task plan")
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="Overall user-friendly message for the task plan"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskPlan",
|
||||
{"en": "Task Plan", "fr": "Plan de tâches"},
|
||||
|
|
@ -613,10 +921,16 @@ register_model_labels(
|
|||
# Resolve forward references
|
||||
TaskContext.update_forward_refs()
|
||||
|
||||
|
||||
class PromptPlaceholder(BaseModel, ModelMixin):
|
||||
label: str
|
||||
content: str
|
||||
summaryAllowed: bool = Field(default=False, description="Whether host may summarize content before sending to AI")
|
||||
summaryAllowed: bool = Field(
|
||||
default=False,
|
||||
description="Whether host may summarize content before sending to AI",
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"PromptPlaceholder",
|
||||
{"en": "Prompt Placeholder", "fr": "Espace réservé d'invite"},
|
||||
|
|
@ -627,11 +941,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class PromptBundle(BaseModel, ModelMixin):
|
||||
prompt: str
|
||||
placeholders: List[PromptPlaceholder] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"PromptBundle", {"en": "Prompt Bundle", "fr": "Lot d'invite"},
|
||||
"PromptBundle",
|
||||
{"en": "Prompt Bundle", "fr": "Lot d'invite"},
|
||||
{
|
||||
"prompt": {"en": "Prompt", "fr": "Invite"},
|
||||
"placeholders": {"en": "Placeholders", "fr": "Espaces réservés"},
|
||||
|
|
|
|||
216
modules/datamodels/datamodelChatbot.py
Normal file
216
modules/datamodels/datamodelChatbot.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""Chatbot API models for request/response handling."""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
|
||||
|
||||
# Chatbot API Models
|
||||
class MessageItem(BaseModel, ModelMixin):
|
||||
"""Individual message in a thread"""
|
||||
|
||||
role: str = Field(..., description="Message role (user or assistant)")
|
||||
content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel, ModelMixin):
|
||||
"""Request model for posting a chat message"""
|
||||
|
||||
thread_id: Optional[str] = Field(
|
||||
None, description="Thread ID (creates new thread if not provided)"
|
||||
)
|
||||
message: str = Field(..., description="User message content")
|
||||
tools: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="List of tool IDs to use. If not provided, all user's tools will be used",
|
||||
)
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel, ModelMixin):
|
||||
"""Response model for posting a chat message"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
messages: List[MessageItem] = Field(..., description="All messages in thread")
|
||||
|
||||
|
||||
class ThreadSummary(BaseModel, ModelMixin):
|
||||
"""Summary of a chat thread for list view"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
thread_name: str = Field(..., description="Thread name")
|
||||
date_created: float = Field(..., description="Thread creation timestamp")
|
||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
||||
|
||||
|
||||
class ThreadListResponse(BaseModel, ModelMixin):
|
||||
"""Response model for listing all threads"""
|
||||
|
||||
threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
|
||||
|
||||
|
||||
class ThreadDetail(BaseModel, ModelMixin):
|
||||
"""Detailed view of a single thread"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
date_created: float = Field(..., description="Thread creation timestamp")
|
||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
||||
messages: List[MessageItem] = Field(
|
||||
..., description="All messages in chronological order"
|
||||
)
|
||||
|
||||
|
||||
class RenameThreadRequest(BaseModel, ModelMixin):
|
||||
"""Request model for renaming a thread"""
|
||||
|
||||
new_name: str = Field(..., description="New name for the thread")
|
||||
|
||||
|
||||
class DeleteResponse(BaseModel, ModelMixin):
|
||||
"""Response model for delete operations"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
thread_id: str = Field(..., description="Deleted thread ID")
|
||||
|
||||
|
||||
# Tool Management Models
|
||||
class ToolInfo(BaseModel, ModelMixin):
|
||||
"""Information about a chatbot tool"""
|
||||
|
||||
id: str = Field(..., description="Tool UUID")
|
||||
tool_id: str = Field(
|
||||
..., description="Tool identifier (e.g., 'shared.tavily_search')"
|
||||
)
|
||||
name: str = Field(..., description="Tool function name")
|
||||
label: str = Field(..., description="Display label for the tool")
|
||||
category: str = Field(..., description="Tool category (shared or customer)")
|
||||
description: str = Field(..., description="Tool description")
|
||||
is_active: bool = Field(..., description="Whether the tool is active")
|
||||
date_created: float = Field(..., description="Creation timestamp")
|
||||
date_updated: float = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class ToolListResponse(BaseModel, ModelMixin):
|
||||
"""Response model for listing all tools"""
|
||||
|
||||
tools: List[ToolInfo] = Field(..., description="List of available tools")
|
||||
|
||||
|
||||
class GrantToolRequest(BaseModel, ModelMixin):
|
||||
"""Request model for granting a tool to a user"""
|
||||
|
||||
user_id: str = Field(..., description="User ID to grant the tool to")
|
||||
tool_id: str = Field(..., description="Tool UUID from tools table")
|
||||
|
||||
|
||||
class GrantToolResponse(BaseModel, ModelMixin):
|
||||
"""Response model after granting a tool"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
tool_id: str = Field(..., description="Tool UUID")
|
||||
|
||||
|
||||
class RevokeToolRequest(BaseModel, ModelMixin):
|
||||
"""Request model for revoking a tool from a user"""
|
||||
|
||||
user_id: str = Field(..., description="User ID to revoke the tool from")
|
||||
tool_id: str = Field(..., description="Tool UUID from tools table")
|
||||
|
||||
|
||||
class RevokeToolResponse(BaseModel, ModelMixin):
|
||||
"""Response model after revoking a tool"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
tool_id: str = Field(..., description="Tool UUID")
|
||||
|
||||
|
||||
class UpdateToolRequest(BaseModel, ModelMixin):
|
||||
"""Request model for updating a tool's label and description"""
|
||||
|
||||
label: Optional[str] = Field(None, description="New label for the tool")
|
||||
description: Optional[str] = Field(None, description="New description for the tool")
|
||||
|
||||
|
||||
class UpdateToolResponse(BaseModel, ModelMixin):
|
||||
"""Response model after updating a tool"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
tool_id: str = Field(..., description="Tool UUID")
|
||||
updated_fields: List[str] = Field(..., description="List of updated field names")
|
||||
|
||||
|
||||
# Register model labels for internationalization
|
||||
register_model_labels(
|
||||
"MessageItem",
|
||||
{"en": "Message Item", "fr": "Élément de message"},
|
||||
{
|
||||
"role": {"en": "Role", "fr": "Rôle"},
|
||||
"content": {"en": "Content", "fr": "Contenu"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessageRequest",
|
||||
{"en": "Chat Message Request", "fr": "Demande de message de chat"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"message": {"en": "Message", "fr": "Message"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessageResponse",
|
||||
{"en": "Chat Message Response", "fr": "Réponse du message de chat"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"messages": {"en": "Messages", "fr": "Messages"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadSummary",
|
||||
{"en": "Thread Summary", "fr": "Résumé du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"thread_name": {"en": "Thread Name", "fr": "Nom du fil"},
|
||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadListResponse",
|
||||
{"en": "Thread List Response", "fr": "Réponse de liste de fils"},
|
||||
{
|
||||
"threads": {"en": "Threads", "fr": "Fils"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadDetail",
|
||||
{"en": "Thread Detail", "fr": "Détail du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
||||
"messages": {"en": "Messages", "fr": "Messages"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"RenameThreadRequest",
|
||||
{"en": "Rename Thread Request", "fr": "Demande de renommage de fil"},
|
||||
{
|
||||
"new_name": {"en": "New Name", "fr": "Nouveau nom"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"DeleteResponse",
|
||||
{"en": "Delete Response", "fr": "Réponse de suppression"},
|
||||
{
|
||||
"message": {"en": "Message", "fr": "Message"},
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
},
|
||||
)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""Security models: Token and AuthEvent."""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from .datamodelUam import AuthAuthority
|
||||
|
|
@ -13,25 +13,44 @@ class TokenStatus(str, Enum):
|
|||
ACTIVE = "active"
|
||||
REVOKED = "revoked"
|
||||
|
||||
|
||||
class Token(BaseModel, ModelMixin):
|
||||
id: Optional[str] = None
|
||||
userId: str
|
||||
authority: AuthAuthority
|
||||
connectionId: Optional[str] = Field(None, description="ID of the connection this token belongs to")
|
||||
connectionId: Optional[str] = Field(
|
||||
None, description="ID of the connection this token belongs to"
|
||||
)
|
||||
tokenAccess: str
|
||||
tokenType: str = "bearer"
|
||||
expiresAt: float = Field(description="When the token expires (UTC timestamp in seconds)")
|
||||
expiresAt: float = Field(
|
||||
description="When the token expires (UTC timestamp in seconds)"
|
||||
)
|
||||
tokenRefresh: Optional[str] = None
|
||||
createdAt: Optional[float] = Field(None, description="When the token was created (UTC timestamp in seconds)")
|
||||
status: TokenStatus = Field(default=TokenStatus.ACTIVE, description="Token status: active/revoked")
|
||||
revokedAt: Optional[float] = Field(None, description="When the token was revoked (UTC timestamp in seconds)")
|
||||
revokedBy: Optional[str] = Field(None, description="User ID who revoked the token (admin/self)")
|
||||
createdAt: Optional[float] = Field(
|
||||
None, description="When the token was created (UTC timestamp in seconds)"
|
||||
)
|
||||
status: TokenStatus = Field(
|
||||
default=TokenStatus.ACTIVE, description="Token status: active/revoked"
|
||||
)
|
||||
revokedAt: Optional[float] = Field(
|
||||
None, description="When the token was revoked (UTC timestamp in seconds)"
|
||||
)
|
||||
revokedBy: Optional[str] = Field(
|
||||
None, description="User ID who revoked the token (admin/self)"
|
||||
)
|
||||
reason: Optional[str] = Field(None, description="Optional revocation reason")
|
||||
sessionId: Optional[str] = Field(None, description="Logical session grouping for logout revocation")
|
||||
mandateId: Optional[str] = Field(None, description="Mandate ID for tenant scoping of the token")
|
||||
sessionId: Optional[str] = Field(
|
||||
None, description="Logical session grouping for logout revocation"
|
||||
)
|
||||
mandateId: Optional[str] = Field(
|
||||
None, description="Mandate ID for tenant scoping of the token"
|
||||
)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"Token",
|
||||
{"en": "Token", "fr": "Jeton"},
|
||||
|
|
@ -54,15 +73,64 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class AuthEvent(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the auth event", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
userId: str = Field(description="ID of the user this event belongs to", frontend_type="text", frontend_readonly=True, frontend_required=True)
|
||||
eventType: str = Field(description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')", frontend_type="text", frontend_readonly=True, frontend_required=True)
|
||||
timestamp: float = Field(default_factory=get_utc_timestamp, description="Unix timestamp when the event occurred", frontend_type="datetime", frontend_readonly=True, frontend_required=True)
|
||||
ipAddress: Optional[str] = Field(default=None, description="IP address from which the event originated", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
userAgent: Optional[str] = Field(default=None, description="User agent string from the request", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
success: bool = Field(default=True, description="Whether the authentication event was successful", frontend_type="boolean", frontend_readonly=True, frontend_required=True)
|
||||
details: Optional[str] = Field(default=None, description="Additional details about the event", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Unique ID of the auth event",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
userId: str = Field(
|
||||
description="ID of the user this event belongs to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True,
|
||||
)
|
||||
eventType: str = Field(
|
||||
description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True,
|
||||
)
|
||||
timestamp: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="Unix timestamp when the event occurred",
|
||||
frontend_type="datetime",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True,
|
||||
)
|
||||
ipAddress: Optional[str] = Field(
|
||||
default=None,
|
||||
description="IP address from which the event originated",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
userAgent: Optional[str] = Field(
|
||||
default=None,
|
||||
description="User agent string from the request",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
success: bool = Field(
|
||||
default=True,
|
||||
description="Whether the authentication event was successful",
|
||||
frontend_type="boolean",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True,
|
||||
)
|
||||
details: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Additional details about the event",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"AuthEvent",
|
||||
{"en": "Authentication Event", "fr": "Événement d'authentification"},
|
||||
|
|
@ -77,5 +145,3 @@ register_model_labels(
|
|||
"details": {"en": "Details", "fr": "Détails"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
1
modules/features/chatBot/chatbotTools/__init__.py
Normal file
1
modules/features/chatBot/chatbotTools/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Contains all tools available for the chatbot to use."""
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tools that are shared between multiple customers go here."""
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
"""Althaus Database Query Tool for LangGraph.
|
||||
|
||||
This tool provides database query capabilities for the Althaus database
|
||||
via an external REST API. Only SELECT queries are allowed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Annotated
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _mock_api_call(*, sql_query: str) -> dict:
|
||||
"""Mock the external REST API call to Althaus database.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL SELECT query to execute
|
||||
|
||||
Returns:
|
||||
A dictionary containing the query results with columns and rows
|
||||
"""
|
||||
# Simulate network delay
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Mock response data based on common query patterns
|
||||
if "users" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["id", "username", "email", "created_at"],
|
||||
"rows": [
|
||||
[1, "john_doe", "john@example.com", "2024-01-15"],
|
||||
[2, "jane_smith", "jane@example.com", "2024-02-20"],
|
||||
[3, "bob_wilson", "bob@example.com", "2024-03-10"],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
elif "products" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["product_id", "name", "price", "stock"],
|
||||
"rows": [
|
||||
[101, "Widget A", 29.99, 150],
|
||||
[102, "Widget B", 39.99, 75],
|
||||
[103, "Widget C", 19.99, 200],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
elif "orders" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["order_id", "customer_id", "total", "status"],
|
||||
"rows": [
|
||||
[5001, 1, 129.99, "completed"],
|
||||
[5002, 2, 89.50, "pending"],
|
||||
[5003, 1, 199.99, "shipped"],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
else:
|
||||
# Generic response for other queries
|
||||
return {
|
||||
"columns": ["id", "value", "description"],
|
||||
"rows": [
|
||||
[1, "Sample 1", "First sample entry"],
|
||||
[2, "Sample 2", "Second sample entry"],
|
||||
],
|
||||
"row_count": 2,
|
||||
}
|
||||
|
||||
|
||||
def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
|
||||
"""Validate that the query is a SELECT statement only.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL query to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message)
|
||||
"""
|
||||
# Remove leading/trailing whitespace and convert to lowercase for checking
|
||||
normalized_query = sql_query.strip().lower()
|
||||
|
||||
# Check if query starts with SELECT
|
||||
if not normalized_query.startswith("select"):
|
||||
return False, "Query must be a SELECT statement"
|
||||
|
||||
# Check for dangerous keywords that should not be in a SELECT query
|
||||
dangerous_keywords = [
|
||||
"insert",
|
||||
"update",
|
||||
"delete",
|
||||
"drop",
|
||||
"create",
|
||||
"alter",
|
||||
"truncate",
|
||||
"grant",
|
||||
"revoke",
|
||||
"exec",
|
||||
"execute",
|
||||
]
|
||||
|
||||
for keyword in dangerous_keywords:
|
||||
# Use word boundary to match whole words only
|
||||
if re.search(rf"\b{keyword}\b", normalized_query):
|
||||
return False, f"Query contains forbidden keyword: {keyword.upper()}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
|
||||
"""Format query results into a readable string.
|
||||
|
||||
Args:
|
||||
columns: List of column names
|
||||
rows: List of row data
|
||||
row_count: Total number of rows
|
||||
|
||||
Returns:
|
||||
Formatted string representation of the results
|
||||
"""
|
||||
if row_count == 0:
|
||||
return "Query executed successfully but returned no results."
|
||||
|
||||
# Calculate column widths
|
||||
col_widths = [len(str(col)) for col in columns]
|
||||
for row in rows:
|
||||
for i, cell in enumerate(row):
|
||||
col_widths[i] = max(col_widths[i], len(str(cell)))
|
||||
|
||||
# Build header
|
||||
header_parts = []
|
||||
for col, width in zip(columns, col_widths):
|
||||
header_parts.append(str(col).ljust(width))
|
||||
header = " | ".join(header_parts)
|
||||
separator = "-" * len(header)
|
||||
|
||||
# Build rows
|
||||
row_lines = []
|
||||
for row in rows:
|
||||
row_parts = []
|
||||
for cell, width in zip(row, col_widths):
|
||||
row_parts.append(str(cell).ljust(width))
|
||||
row_lines.append(" | ".join(row_parts))
|
||||
|
||||
# Combine all parts
|
||||
result_parts = [
|
||||
f"Query returned {row_count} row(s):\n",
|
||||
header,
|
||||
separator,
|
||||
"\n".join(row_lines),
|
||||
]
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
|
||||
@tool
|
||||
async def query_althaus_database(
|
||||
sql_query: Annotated[
|
||||
str, "The SQL SELECT query to execute against the Althaus database"
|
||||
],
|
||||
) -> str:
|
||||
"""Execute a SELECT query against the Althaus database via REST API.
|
||||
|
||||
Use this tool to query data from the Althaus database. Only SELECT statements
|
||||
are allowed for security reasons. The query will be forwarded to an external
|
||||
REST API and the results will be returned in a formatted table.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")
|
||||
|
||||
Returns:
|
||||
A formatted string containing the query results with columns and rows
|
||||
"""
|
||||
try:
|
||||
# Validate the query
|
||||
is_valid, error_msg = _validate_select_query(sql_query=sql_query)
|
||||
if not is_valid:
|
||||
logger.warning(f"Invalid query attempt: {sql_query[:100]}...")
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
logger.info(f"Executing Althaus database query: {sql_query[:100]}...")
|
||||
|
||||
# Mock the external REST API call
|
||||
# In production, this would be replaced with actual REST API call:
|
||||
# response = await httpx.AsyncClient().post(
|
||||
# "https://api.althaus.example.com/query",
|
||||
# json={"query": sql_query},
|
||||
# headers={"Authorization": f"Bearer {api_key}"}
|
||||
# )
|
||||
# result = response.json()
|
||||
|
||||
result = await _mock_api_call(sql_query=sql_query)
|
||||
|
||||
# Format and return results
|
||||
formatted_output = _format_results(
|
||||
columns=result["columns"],
|
||||
rows=result["rows"],
|
||||
row_count=result["row_count"],
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Query completed successfully, returned {result['row_count']} row(s)"
|
||||
)
|
||||
return formatted_output
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in query_althaus_database tool: {str(e)}")
|
||||
return f"Error executing query: {str(e)}"
|
||||
|
|
@ -0,0 +1,362 @@
|
|||
"""Power BI Query Tool for LangGraph.
|
||||
|
||||
This tool provides DAX query capabilities for Power BI datasets
|
||||
via the Power BI REST API. Only read-only queries are allowed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import functools
|
||||
from typing import Annotated
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from msal import ConfidentialClientApplication, SerializableTokenCache
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Configuration constants - encapsulated in this file
|
||||
POWERBI_DATASET_ID = APP_CONFIG.get("VALUEON_POWERBI_DATASET_ID", "")
|
||||
POWERBI_CLIENT_ID = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_ID", "")
|
||||
POWERBI_CLIENT_SECRET = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_SECRET", "")
|
||||
POWERBI_TENANT_ID = APP_CONFIG.get("VALUEON_POWERBI_TENANT_ID", "")
|
||||
POWERBI_BASE_URL = "https://api.powerbi.com/v1.0/myorg"
|
||||
POWERBI_AUTHORITY_BASE = "https://login.microsoftonline.com"
|
||||
POWERBI_SCOPE = ["https://analysis.windows.net/powerbi/api/.default"]
|
||||
|
||||
# Limit results to prevent excessive context usage
|
||||
MAX_ROWS_LIMIT = 100
|
||||
|
||||
|
||||
def _validate_environment() -> tuple[bool, str]:
|
||||
"""Validate that all required environment variables are set.
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message)
|
||||
"""
|
||||
missing = []
|
||||
if not POWERBI_DATASET_ID:
|
||||
missing.append("POWERBI_DATASET_ID")
|
||||
if not POWERBI_CLIENT_ID:
|
||||
missing.append("POWERBI_CLIENT_ID")
|
||||
if not POWERBI_CLIENT_SECRET:
|
||||
missing.append("POWERBI_CLIENT_SECRET")
|
||||
if not POWERBI_TENANT_ID:
|
||||
missing.append("POWERBI_TENANT_ID")
|
||||
|
||||
if missing:
|
||||
return False, f"Missing required environment variables: {', '.join(missing)}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _validate_dax_query(*, dax_query: str) -> tuple[bool, str]:
|
||||
"""Validate that the query is a valid DAX query.
|
||||
|
||||
Args:
|
||||
dax_query: The DAX query to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message)
|
||||
"""
|
||||
# Remove leading/trailing whitespace
|
||||
normalized_query = dax_query.strip()
|
||||
|
||||
if not normalized_query:
|
||||
return False, "Query cannot be empty"
|
||||
|
||||
# DAX queries typically start with EVALUATE, DEFINE, or are table expressions
|
||||
# We'll be lenient and just check it's not trying to do something dangerous
|
||||
# DAX is read-only by nature, but we validate structure
|
||||
|
||||
# Check for minimum length
|
||||
if len(normalized_query) < 5:
|
||||
return False, "Query is too short to be valid"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _get_access_token_sync(
|
||||
*,
|
||||
tenant_id: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
authority_base: str = POWERBI_AUTHORITY_BASE,
|
||||
cache: SerializableTokenCache | None = None,
|
||||
) -> str:
|
||||
"""Get Power BI access token using MSAL (synchronous).
|
||||
|
||||
Args:
|
||||
tenant_id: Azure AD tenant ID
|
||||
client_id: Application client ID
|
||||
client_secret: Application client secret
|
||||
authority_base: Azure AD authority base URL
|
||||
cache: Optional token cache for reuse
|
||||
|
||||
Returns:
|
||||
Access token string
|
||||
|
||||
Raises:
|
||||
RuntimeError: If token acquisition fails
|
||||
"""
|
||||
authority = f"{authority_base}/{tenant_id}"
|
||||
|
||||
app = ConfidentialClientApplication(
|
||||
client_id=client_id,
|
||||
authority=authority,
|
||||
client_credential=client_secret,
|
||||
token_cache=cache,
|
||||
)
|
||||
|
||||
# Try cache first; fall back to client credentials
|
||||
result = app.acquire_token_silent(
|
||||
POWERBI_SCOPE, account=None
|
||||
) or app.acquire_token_for_client(scopes=POWERBI_SCOPE)
|
||||
|
||||
if "access_token" not in result:
|
||||
raise RuntimeError(
|
||||
f"MSAL token error: {result.get('error')} - {result.get('error_description')}"
|
||||
)
|
||||
|
||||
return result["access_token"]
|
||||
|
||||
|
||||
async def _get_access_token_async(
|
||||
*,
|
||||
tenant_id: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Get Power BI access token using MSAL (asynchronous).
|
||||
|
||||
Args:
|
||||
tenant_id: Azure AD tenant ID
|
||||
client_id: Application client ID
|
||||
client_secret: Application client secret
|
||||
**kwargs: Additional arguments for _get_access_token_sync
|
||||
|
||||
Returns:
|
||||
Access token string
|
||||
"""
|
||||
# Create a partial function with arguments pre-filled
|
||||
func = functools.partial(
|
||||
_get_access_token_sync,
|
||||
tenant_id=tenant_id,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
**kwargs,
|
||||
)
|
||||
# Offload the blocking MSAL HTTP call to a worker thread
|
||||
return await anyio.to_thread.run_sync(func)
|
||||
|
||||
|
||||
async def _execute_dax_query(
|
||||
*, dax_query: str, dataset_id: str, access_token: str
|
||||
) -> dict:
|
||||
"""Execute a DAX query against Power BI dataset.
|
||||
|
||||
Args:
|
||||
dax_query: The DAX query to execute
|
||||
dataset_id: Power BI dataset ID
|
||||
access_token: Access token for authentication
|
||||
|
||||
Returns:
|
||||
Dictionary containing query results
|
||||
|
||||
Raises:
|
||||
RuntimeError: If query execution fails
|
||||
"""
|
||||
url = f"{POWERBI_BASE_URL}/datasets/{dataset_id}/executeQueries"
|
||||
|
||||
body = {
|
||||
"queries": [{"query": dax_query}],
|
||||
"serializerSettings": {"includeNulls": True},
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.post(url, headers=headers, json=body)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Power BI executeQueries failed: {resp.status_code} - {resp.text}"
|
||||
)
|
||||
|
||||
payload = resp.json()
|
||||
|
||||
try:
|
||||
rows = payload["results"][0]["tables"][0]["rows"]
|
||||
except (KeyError, IndexError) as e:
|
||||
raise RuntimeError("Unexpected executeQueries response structure") from e
|
||||
|
||||
# Extract column names from the first row if available
|
||||
if rows:
|
||||
columns = list(rows[0].keys())
|
||||
else:
|
||||
columns = []
|
||||
|
||||
return {"columns": columns, "rows": rows}
|
||||
|
||||
|
||||
def _strip_table_qualifier(*, column_name: str) -> str:
|
||||
"""Strip table qualifier from column name.
|
||||
|
||||
Power BI often returns columns as 'Table[Column]'. This strips to 'Column'.
|
||||
|
||||
Args:
|
||||
column_name: The column name to process
|
||||
|
||||
Returns:
|
||||
Processed column name
|
||||
"""
|
||||
if "[" in column_name and column_name.endswith("]"):
|
||||
return column_name.split("[", 1)[1][:-1]
|
||||
return column_name
|
||||
|
||||
|
||||
def _format_results(*, columns: list[str], rows: list[dict], max_rows: int) -> str:
|
||||
"""Format query results into a readable string.
|
||||
|
||||
Args:
|
||||
columns: List of column names
|
||||
rows: List of row data (as dictionaries)
|
||||
max_rows: Maximum number of rows to display
|
||||
|
||||
Returns:
|
||||
Formatted string representation of the results
|
||||
"""
|
||||
total_rows = len(rows)
|
||||
|
||||
if total_rows == 0:
|
||||
return "Query executed successfully but returned no results."
|
||||
|
||||
# Strip table qualifiers from column names
|
||||
clean_columns = [_strip_table_qualifier(column_name=col) for col in columns]
|
||||
|
||||
# Limit rows to max_rows
|
||||
display_rows = rows[:max_rows]
|
||||
truncated = total_rows > max_rows
|
||||
|
||||
# Calculate column widths
|
||||
col_widths = [len(str(col)) for col in clean_columns]
|
||||
for row in display_rows:
|
||||
for i, col in enumerate(columns):
|
||||
value = row.get(col, "")
|
||||
col_widths[i] = max(col_widths[i], len(str(value)))
|
||||
|
||||
# Build header
|
||||
header_parts = []
|
||||
for col, width in zip(clean_columns, col_widths):
|
||||
header_parts.append(str(col).ljust(width))
|
||||
header = " | ".join(header_parts)
|
||||
separator = "-" * len(header)
|
||||
|
||||
# Build rows
|
||||
row_lines = []
|
||||
for row in display_rows:
|
||||
row_parts = []
|
||||
for col, width in zip(columns, col_widths):
|
||||
value = row.get(col, "")
|
||||
row_parts.append(str(value).ljust(width))
|
||||
row_lines.append(" | ".join(row_parts))
|
||||
|
||||
# Combine all parts
|
||||
result_parts = [
|
||||
f"Query returned {total_rows} row(s):",
|
||||
]
|
||||
|
||||
if truncated:
|
||||
result_parts.append(
|
||||
f"(Results limited to {max_rows} rows for context efficiency)\n"
|
||||
)
|
||||
else:
|
||||
result_parts.append("")
|
||||
|
||||
result_parts.extend([header, separator, "\n".join(row_lines)])
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
|
||||
@tool
|
||||
async def query_powerbi_data(
|
||||
dax_query: Annotated[str, "The DAX query to execute against the Power BI dataset"],
|
||||
) -> str:
|
||||
"""Execute a DAX query against the Power BI dataset to access warehouse inventory data.
|
||||
|
||||
This tool provides access to a Power BI table called 'data_full' which contains
|
||||
articles available in the warehouse of the user. Use DAX (Data Analysis Expressions)
|
||||
queries to retrieve and analyze this inventory data.
|
||||
|
||||
Available table:
|
||||
- 'data_full': Contains warehouse inventory articles and their details
|
||||
|
||||
Common query patterns:
|
||||
- View all data: EVALUATE 'data_full'
|
||||
- With filter: EVALUATE FILTER('data_full', [Column] = "Value")
|
||||
- Top N rows: EVALUATE TOPN(10, 'data_full', [Column], DESC)
|
||||
- Calculated: EVALUATE SUMMARIZE('data_full', [Column1], "Total", SUM([Column2]))
|
||||
|
||||
Results are limited to 100 rows maximum for efficiency.
|
||||
|
||||
Args:
|
||||
dax_query: The DAX query to execute (e.g., "EVALUATE 'data_full'")
|
||||
|
||||
Returns:
|
||||
A formatted string containing the query results with columns and rows
|
||||
"""
|
||||
try:
|
||||
# Validate environment configuration
|
||||
is_valid_env, error_msg = _validate_environment()
|
||||
if not is_valid_env:
|
||||
logger.error(f"Environment validation failed: {error_msg}")
|
||||
return f"Configuration Error: {error_msg}"
|
||||
|
||||
# Validate the query
|
||||
is_valid_query, error_msg = _validate_dax_query(dax_query=dax_query)
|
||||
if not is_valid_query:
|
||||
logger.warning(f"Invalid query attempt: {dax_query[:100]}...")
|
||||
return f"Query Validation Error: {error_msg}"
|
||||
|
||||
logger.info(f"Executing Power BI query: {dax_query[:100]}...")
|
||||
|
||||
# Get access token
|
||||
access_token = await _get_access_token_async(
|
||||
tenant_id=POWERBI_TENANT_ID,
|
||||
client_id=POWERBI_CLIENT_ID,
|
||||
client_secret=POWERBI_CLIENT_SECRET,
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
result = await _execute_dax_query(
|
||||
dax_query=dax_query,
|
||||
dataset_id=POWERBI_DATASET_ID,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# Format and return results
|
||||
formatted_output = _format_results(
|
||||
columns=result["columns"], rows=result["rows"], max_rows=MAX_ROWS_LIMIT
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Query completed successfully, returned {len(result['rows'])} row(s)"
|
||||
)
|
||||
return formatted_output
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Runtime error in query_powerbi_data tool: {str(e)}")
|
||||
return f"Error executing query: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in query_powerbi_data tool: {str(e)}")
|
||||
return f"Unexpected error: {str(e)}"
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""Shared tools available across all chatbot implementations."""
|
||||
|
||||
from modules.features.chatBot.chatbotTools.sharedTools.toolTavilySearch import (
|
||||
tavily_search,
|
||||
)
|
||||
|
||||
__all__ = ["tavily_search"]
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
"""Tool for sending streaming status updates to users."""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
@tool
|
||||
def send_streaming_message(message: str) -> str:
|
||||
"""Send a streaming message to the user to provide updates during processing.
|
||||
|
||||
Use this tool to send short status updates to the user while you are working
|
||||
on their request. This helps keep the user informed about what you are doing.
|
||||
|
||||
Args:
|
||||
message: A short message describing what you are currently doing.
|
||||
Examples: "Searching database for relevant information..."
|
||||
"Analyzing search results..."
|
||||
"Processing your request..."
|
||||
|
||||
Returns:
|
||||
A confirmation that the message was sent.
|
||||
"""
|
||||
# This tool doesn't actually do anything - it's just for the AI to signal
|
||||
# what it's doing to the frontend via the tool call mechanism
|
||||
return f"Status update sent: {message}"
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
"""Tavily Search Tool for LangGraph.
|
||||
|
||||
This tool provides web search capabilities using the Tavily API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from langchain_core.tools import tool
|
||||
from modules.connectors.connectorAiTavily import ConnectorWeb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool
|
||||
async def tavily_search(
|
||||
query: Annotated[str, "The search query to look up on the web"],
|
||||
) -> str:
|
||||
"""Search the web using Tavily API.
|
||||
|
||||
Use this tool to search for current information, news, or any web content.
|
||||
The tool returns relevant search results including titles and URLs.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
|
||||
Returns:
|
||||
A formatted string containing search results with titles and URLs
|
||||
"""
|
||||
try:
|
||||
# Create connector instance
|
||||
connector = await ConnectorWeb.create()
|
||||
|
||||
# Perform search with default parameters
|
||||
results = await connector._search(
|
||||
query=query,
|
||||
max_results=5,
|
||||
search_depth="basic",
|
||||
include_answer=True,
|
||||
include_raw_content=False,
|
||||
)
|
||||
|
||||
# Format results
|
||||
if not results:
|
||||
return f"No results found for query: {query}"
|
||||
|
||||
formatted_results = [f"Search results for '{query}':\n"]
|
||||
for i, result in enumerate(results, 1):
|
||||
formatted_results.append(f"{i}. {result.title}")
|
||||
formatted_results.append(f" URL: {result.url}\n")
|
||||
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in tavily_search tool: {str(e)}")
|
||||
return f"Error performing search: {str(e)}"
|
||||
197
modules/features/chatBot/database.py
Normal file
197
modules/features/chatBot/database.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
from typing import AsyncIterator
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
# Tools Table
|
||||
class Tool(Base):
|
||||
"""Available chatbot tools.
|
||||
|
||||
Stores information about all available tools that can be assigned to users.
|
||||
Each tool has a unique tool_id that corresponds to the registry tool_id.
|
||||
"""
|
||||
|
||||
__tablename__ = "tools"
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
label: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
description: Mapped[str] = mapped_column(String(1000), nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
date_created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
date_updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# User-Tool Mapping Table
|
||||
class UserToolMapping(Base):
|
||||
"""Mapping of users to their available tools.
|
||||
|
||||
Many-to-many relationship between users and tools.
|
||||
- One user can have multiple tools
|
||||
- One tool can be assigned to multiple users
|
||||
|
||||
The combination of user_id and tool_id is unique.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_tools"
|
||||
__table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
date_granted: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
date_updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# User Thread Mapping Table
|
||||
class UserThreadMapping(Base):
|
||||
"""Mapping of users to their chat threads.
|
||||
|
||||
Used to keep track of which user owns which chat thread.
|
||||
Also stores meta data like thread name.
|
||||
|
||||
1:N relationship between user and thread. A thread belongs to exactly one user.
|
||||
A user can have multiple threads.
|
||||
Thread_id is unique in the table.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_threads"
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
thread_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||
thread_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
date_created: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
date_updated: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# Dependency that pulls the sessionmaker off app.state
|
||||
# This is set in app.py on startup in @asynccontextmanager
|
||||
# TODO: If we use SQLAlchemy in other places, we can move this to a shared module
|
||||
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
|
||||
SessionLocal: async_sessionmaker[AsyncSession] = (
|
||||
request.app.state.checkpoint_sessionmaker
|
||||
)
|
||||
async with SessionLocal() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Optional helper to init tables at startup (demo only)
|
||||
async def init_models(engine) -> None:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def sync_tools_from_registry(session: AsyncSession) -> None:
|
||||
"""Sync tools from tool registry to database.
|
||||
|
||||
This function:
|
||||
- Adds new tools from the registry to the database
|
||||
- Updates existing tools with current registry information
|
||||
- Marks tools not present in the registry as inactive
|
||||
|
||||
Should be called on application startup after database initialization.
|
||||
|
||||
Args:
|
||||
session: Active database session
|
||||
"""
|
||||
import logging
|
||||
from sqlalchemy import select
|
||||
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Syncing tools from registry to database...")
|
||||
|
||||
# Get all tools from the registry
|
||||
registry = get_registry()
|
||||
registry_tools = registry.get_all_tools()
|
||||
|
||||
# Create a set of tool_ids from the registry
|
||||
registry_tool_ids = {tool.tool_id for tool in registry_tools}
|
||||
|
||||
logger.info(f"Found {len(registry_tools)} tools in registry")
|
||||
|
||||
# Get all existing tools from the database
|
||||
result = await session.execute(select(Tool))
|
||||
db_tools = result.scalars().all()
|
||||
db_tools_by_tool_id = {tool.tool_id: tool for tool in db_tools}
|
||||
|
||||
logger.info(f"Found {len(db_tools)} tools in database")
|
||||
|
||||
# Track changes
|
||||
added_count = 0
|
||||
updated_count = 0
|
||||
deactivated_count = 0
|
||||
|
||||
# Sync tools from registry to database
|
||||
for registry_tool in registry_tools:
|
||||
if registry_tool.tool_id in db_tools_by_tool_id:
|
||||
# Tool exists - update it
|
||||
# Preserve label and description (user-editable fields)
|
||||
db_tool = db_tools_by_tool_id[registry_tool.tool_id]
|
||||
db_tool.name = registry_tool.name
|
||||
db_tool.category = registry_tool.category
|
||||
db_tool.is_active = True
|
||||
db_tool.date_updated = datetime.now(timezone.utc)
|
||||
updated_count += 1
|
||||
logger.debug(f"Updated tool: {registry_tool.tool_id}")
|
||||
else:
|
||||
# Tool doesn't exist - create it
|
||||
new_tool = Tool(
|
||||
tool_id=registry_tool.tool_id,
|
||||
name=registry_tool.name,
|
||||
label=registry_tool.tool_id, # Use tool_id as label per spec
|
||||
category=registry_tool.category,
|
||||
description=registry_tool.description or "",
|
||||
is_active=True,
|
||||
)
|
||||
session.add(new_tool)
|
||||
added_count += 1
|
||||
logger.debug(f"Added new tool: {registry_tool.tool_id}")
|
||||
|
||||
# Mark tools not in registry as inactive
|
||||
for db_tool in db_tools:
|
||||
if db_tool.tool_id not in registry_tool_ids and db_tool.is_active:
|
||||
db_tool.is_active = False
|
||||
db_tool.date_updated = datetime.now(timezone.utc)
|
||||
deactivated_count += 1
|
||||
logger.debug(f"Deactivated tool not in registry: {db_tool.tool_id}")
|
||||
|
||||
logger.info(
|
||||
f"Tool sync complete: {added_count} added, {updated_count} updated, {deactivated_count} deactivated"
|
||||
)
|
||||
1
modules/features/chatBot/domain/__init__.py
Normal file
1
modules/features/chatBot/domain/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Domain logic for chatbot functionality."""
|
||||
301
modules/features/chatBot/domain/chatbot.py
Normal file
301
modules/features/chatBot/domain/chatbot.py
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
"""Chatbot domain logic with LangGraph integration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, AsyncIterator, Any
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
trim_messages,
|
||||
)
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatState(BaseModel):
|
||||
"""Represents the state of a chat session."""
|
||||
|
||||
messages: Annotated[list[BaseMessage], add_messages]
|
||||
|
||||
|
||||
def get_langchain_model(*, model_name: str) -> ChatAnthropic:
|
||||
"""Map permission model names to LangChain ChatAnthropic models.
|
||||
|
||||
Args:
|
||||
model_name: The model name from permissions (e.g., "claude_4_5")
|
||||
|
||||
Returns:
|
||||
Configured ChatAnthropic instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the model name is not supported
|
||||
"""
|
||||
# Model name mapping
|
||||
model_mapping = {
|
||||
"claude_4_5": "claude-sonnet-4-5",
|
||||
# Add more mappings as needed
|
||||
}
|
||||
|
||||
anthropic_model = model_mapping.get(model_name)
|
||||
if not anthropic_model:
|
||||
logger.warning(
|
||||
f"Unknown model name '{model_name}', defaulting to claude-4-5-sonnet"
|
||||
)
|
||||
anthropic_model = "claude-4-5-sonnet"
|
||||
|
||||
return ChatAnthropic(
|
||||
model=anthropic_model,
|
||||
api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
|
||||
temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
|
||||
max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chatbot:
|
||||
"""Represents a chatbot with LangGraph integration."""
|
||||
|
||||
model: Any
|
||||
memory: Any
|
||||
app: Any = None
|
||||
system_prompt: str = "You are a helpful assistant."
|
||||
context_window_size: int = 100000
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
*,
|
||||
model: Any,
|
||||
memory: Any,
|
||||
system_prompt: str,
|
||||
tools: list,
|
||||
context_window_size: int = 100000,
|
||||
) -> "Chatbot":
|
||||
"""Factory method to create and configure a Chatbot instance.
|
||||
|
||||
Args:
|
||||
model: The chat model to use.
|
||||
memory: The chat memory checkpointer to use.
|
||||
system_prompt: The system prompt to initialize the chatbot.
|
||||
tools: List of LangChain tools the chatbot can use.
|
||||
context_window_size: Maximum tokens for context window.
|
||||
|
||||
Returns:
|
||||
A configured Chatbot instance.
|
||||
"""
|
||||
instance = cls(
|
||||
model=model,
|
||||
memory=memory,
|
||||
system_prompt=system_prompt,
|
||||
context_window_size=context_window_size,
|
||||
)
|
||||
instance.app = instance._build_app(memory=memory, tools=tools)
|
||||
return instance
|
||||
|
||||
def _build_app(
|
||||
self, *, memory: Any, tools: list
|
||||
) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
|
||||
"""Builds the chatbot application workflow using LangGraph.
|
||||
|
||||
Args:
|
||||
memory: The chat memory checkpointer to use.
|
||||
tools: The list of tools the chatbot can use.
|
||||
|
||||
Returns:
|
||||
A compiled state graph representing the chatbot application.
|
||||
"""
|
||||
llm_with_tools = self.model.bind_tools(tools=tools)
|
||||
|
||||
def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Selects a window of messages that fit within the context window size.
|
||||
|
||||
Args:
|
||||
msgs: The list of messages to select from.
|
||||
|
||||
Returns:
|
||||
A list of messages that fit within the context window size.
|
||||
"""
|
||||
|
||||
def approx_counter(items: list[BaseMessage]) -> int:
|
||||
"""Approximate token counter for messages.
|
||||
|
||||
Args:
|
||||
items: List of messages to count tokens for.
|
||||
|
||||
Returns:
|
||||
Approximate number of tokens in the messages.
|
||||
"""
|
||||
return sum(len(getattr(m, "content", "") or "") for m in items)
|
||||
|
||||
return trim_messages(
|
||||
msgs,
|
||||
strategy="last",
|
||||
token_counter=approx_counter,
|
||||
max_tokens=self.context_window_size,
|
||||
start_on="human",
|
||||
end_on=("human", "tool"),
|
||||
include_system=True,
|
||||
)
|
||||
|
||||
def agent_node(state: ChatState) -> dict:
|
||||
"""Agent node for the chatbot workflow.
|
||||
|
||||
Args:
|
||||
state: The current chat state.
|
||||
|
||||
Returns:
|
||||
The updated chat state after processing.
|
||||
"""
|
||||
# Select the message window to fit in context (trim if needed)
|
||||
window = select_window(state.messages)
|
||||
|
||||
# Ensure the system prompt is present at the start
|
||||
if not window or not isinstance(window[0], SystemMessage):
|
||||
window = [SystemMessage(content=self.system_prompt)] + window
|
||||
|
||||
# Call the LLM with tools
|
||||
response = llm_with_tools.invoke(window)
|
||||
|
||||
# Return the new state
|
||||
return {"messages": [response]}
|
||||
|
||||
def should_continue(state: ChatState) -> str:
|
||||
"""Determines whether to continue the workflow or end it.
|
||||
|
||||
This conditional edge is called after the agent node to decide
|
||||
whether to continue to the tools node (if the last message contains
|
||||
tool calls) or to end the workflow (if no tool calls are present).
|
||||
|
||||
Args:
|
||||
state: The current chat state.
|
||||
|
||||
Returns:
|
||||
The next node to transition to ("tools" or END).
|
||||
"""
|
||||
# Get the last message
|
||||
last_message = state.messages[-1]
|
||||
|
||||
# Check if the last message contains tool calls
|
||||
# If so, continue to the tools node; otherwise, end the workflow
|
||||
return "tools" if getattr(last_message, "tool_calls", None) else END
|
||||
|
||||
# Compose the workflow
|
||||
workflow = StateGraph(ChatState)
|
||||
workflow.add_node("agent", agent_node)
|
||||
workflow.add_node("tools", ToolNode(tools=tools))
|
||||
workflow.add_edge(START, "agent")
|
||||
workflow.add_conditional_edges("agent", should_continue)
|
||||
workflow.add_edge("tools", "agent")
|
||||
return workflow.compile(checkpointer=memory)
|
||||
|
||||
async def chat(
|
||||
self, *, message: str, chat_id: str = "default"
|
||||
) -> list[BaseMessage]:
|
||||
"""Processes a chat message and returns the chat history.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
chat_id: The chat thread ID.
|
||||
|
||||
Returns:
|
||||
The list of messages in the chat history.
|
||||
"""
|
||||
# Set the right thread ID for memory
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
# Single-turn chat (non-streaming)
|
||||
result = await self.app.ainvoke(
|
||||
{"messages": [HumanMessage(content=message)]}, config=config
|
||||
)
|
||||
|
||||
# Extract and return the messages from the result
|
||||
return result["messages"]
|
||||
|
||||
async def stream_events(
|
||||
self, *, message: str, chat_id: str = "default"
|
||||
) -> AsyncIterator[dict]:
|
||||
"""Stream UI-focused events using astream_events v2.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
chat_id: Logical thread identifier; forwarded in the runnable config so
|
||||
memory and tools are scoped per thread.
|
||||
|
||||
Yields:
|
||||
dict: One of:
|
||||
- ``{"type": "status", "label": str}`` for short progress updates.
|
||||
- ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
|
||||
where ``chat_history`` only includes ``user``/``assistant`` roles.
|
||||
- ``{"type": "error", "message": str}`` if an exception occurs.
|
||||
"""
|
||||
# Thread-aware config for LangGraph/LangChain
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
def _is_root(ev: dict) -> bool:
|
||||
"""Return True if the event is from the root run (v2: empty parent_ids)."""
|
||||
return not ev.get("parent_ids")
|
||||
|
||||
try:
|
||||
async for event in self.app.astream_events(
|
||||
{"messages": [HumanMessage(content=message)]},
|
||||
config=config,
|
||||
version="v2",
|
||||
):
|
||||
etype = event.get("event")
|
||||
ename = event.get("name") or ""
|
||||
edata = event.get("data") or {}
|
||||
|
||||
# Stream human-readable progress via the special send_streaming_message tool
|
||||
if etype == "on_tool_start" and ename == "send_streaming_message":
|
||||
tool_in = edata.get("input") or {}
|
||||
msg = tool_in.get("message")
|
||||
if isinstance(msg, str) and msg.strip():
|
||||
yield {"type": "status", "label": msg.strip()}
|
||||
continue
|
||||
|
||||
# Emit the final payload when the root run finishes
|
||||
if etype == "on_chain_end" and _is_root(event):
|
||||
output_obj = edata.get("output")
|
||||
|
||||
# Extract message list from the graph's final output
|
||||
final_msgs = ChatStreamingHelper.extract_messages_from_output(
|
||||
output_obj=output_obj
|
||||
)
|
||||
|
||||
# Normalize for the frontend (only user/assistant with text content)
|
||||
chat_history_payload: list[dict] = []
|
||||
for m in final_msgs:
|
||||
if isinstance(m, BaseMessage):
|
||||
d = ChatStreamingHelper.message_to_dict(msg=m)
|
||||
elif isinstance(m, dict):
|
||||
d = ChatStreamingHelper.dict_message_to_dict(obj=m)
|
||||
else:
|
||||
continue
|
||||
if d.get("role") in ("user", "assistant") and d.get("content"):
|
||||
chat_history_payload.append(d)
|
||||
|
||||
yield {
|
||||
"type": "final",
|
||||
"response": {
|
||||
"thread": chat_id,
|
||||
"chat_history": chat_history_payload,
|
||||
},
|
||||
}
|
||||
return
|
||||
|
||||
except Exception as exc:
|
||||
# Emit a single error envelope and end the stream
|
||||
logger.error(f"Error in stream_events: {str(exc)}", exc_info=True)
|
||||
yield {"type": "error", "message": f"Error processing request: {exc}"}
|
||||
239
modules/features/chatBot/domain/streaming_helper.py
Normal file
239
modules/features/chatBot/domain/streaming_helper.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""Streaming helper utilities for chat message processing and normalization."""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
Role = Literal["user", "assistant", "system", "tool"]
|
||||
|
||||
|
||||
class ChatStreamingHelper:
|
||||
"""Pure helper methods for streaming and message normalization.
|
||||
|
||||
This class provides static utility methods for converting between different
|
||||
message formats, extracting content, and normalizing message structures
|
||||
for streaming chat applications.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def role_from_message(*, msg: BaseMessage) -> Role:
|
||||
"""Extract the role from a BaseMessage instance.
|
||||
|
||||
Args:
|
||||
msg: The BaseMessage instance to extract the role from.
|
||||
|
||||
Returns:
|
||||
The role as a string literal: "user", "assistant", "system", or "tool".
|
||||
Defaults to "assistant" if the message type is not recognized.
|
||||
|
||||
Examples:
|
||||
>>> from langchain_core.messages import HumanMessage
|
||||
>>> msg = HumanMessage(content="Hello")
|
||||
>>> ChatStreamingHelper.role_from_message(msg=msg)
|
||||
'user'
|
||||
"""
|
||||
if isinstance(msg, HumanMessage):
|
||||
return "user"
|
||||
if isinstance(msg, AIMessage):
|
||||
return "assistant"
|
||||
if isinstance(msg, SystemMessage):
|
||||
return "system"
|
||||
if isinstance(msg, ToolMessage):
|
||||
return "tool"
|
||||
return getattr(msg, "role", "assistant")
|
||||
|
||||
@staticmethod
|
||||
def flatten_content(*, content: Any) -> str:
|
||||
"""Convert complex content structures to plain text.
|
||||
|
||||
This method handles various content formats including strings, lists of
|
||||
content parts, and dictionaries with text fields. It's designed to
|
||||
normalize content from different message sources into a consistent
|
||||
plain text format.
|
||||
|
||||
Args:
|
||||
content: The content to flatten. Can be:
|
||||
- str: Returned as-is after stripping whitespace
|
||||
- list: Each item processed and joined with newlines
|
||||
- dict: Text extracted from "text" or "content" fields
|
||||
- None: Returns empty string
|
||||
- Any other type: Converted to string
|
||||
|
||||
Returns:
|
||||
The flattened content as a plain text string with whitespace stripped.
|
||||
|
||||
Examples:
|
||||
>>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
|
||||
>>> ChatStreamingHelper.flatten_content(content=content)
|
||||
'Hello
|
||||
nworld'
|
||||
|
||||
>>> content = {"text": "Simple message"}
|
||||
>>> ChatStreamingHelper.flatten_content(content=content)
|
||||
'Simple message'
|
||||
"""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts: List[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
if "text" in part and isinstance(part["text"], str):
|
||||
parts.append(part["text"])
|
||||
elif part.get("type") == "text" and isinstance(
|
||||
part.get("text"), str
|
||||
):
|
||||
parts.append(part["text"])
|
||||
elif "content" in part and isinstance(part["content"], str):
|
||||
parts.append(part["content"])
|
||||
else:
|
||||
# Fallback for unknown dictionary structures
|
||||
val = part.get("value")
|
||||
if isinstance(val, str):
|
||||
parts.append(val)
|
||||
else:
|
||||
parts.append(str(part))
|
||||
return "\n".join(p.strip() for p in parts if p is not None)
|
||||
if isinstance(content, dict):
|
||||
if "text" in content and isinstance(content["text"], str):
|
||||
return content["text"].strip()
|
||||
if "content" in content and isinstance(content["content"], str):
|
||||
return content["content"].strip()
|
||||
return str(content).strip()
|
||||
|
||||
@staticmethod
|
||||
def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
|
||||
"""Convert a BaseMessage instance to a dictionary for streaming output.
|
||||
|
||||
This method normalizes BaseMessage instances into a consistent dictionary
|
||||
format suitable for JSON serialization and streaming to clients.
|
||||
|
||||
Args:
|
||||
msg: The BaseMessage instance to convert.
|
||||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- "role": The message role (user, assistant, system, tool)
|
||||
- "content": The flattened message content as plain text
|
||||
- "tool_calls": Tool calls if present (optional)
|
||||
- "name": Message name if present (optional)
|
||||
|
||||
Examples:
|
||||
>>> from langchain_core.messages import HumanMessage
|
||||
>>> msg = HumanMessage(content="Hello there")
|
||||
>>> result = ChatStreamingHelper.message_to_dict(msg=msg)
|
||||
>>> result["role"]
|
||||
'user'
|
||||
>>> result["content"]
|
||||
'Hello there'
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"role": ChatStreamingHelper.role_from_message(msg=msg),
|
||||
"content": ChatStreamingHelper.flatten_content(
|
||||
content=getattr(msg, "content", "")
|
||||
),
|
||||
}
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if tool_calls:
|
||||
payload["tool_calls"] = tool_calls
|
||||
name = getattr(msg, "name", None)
|
||||
if name:
|
||||
payload["name"] = name
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert a dictionary-shaped message to a normalized dictionary.
|
||||
|
||||
This method handles messages that come from serialized state and are
|
||||
represented as dictionaries rather than BaseMessage instances. It
|
||||
normalizes various dictionary formats into a consistent structure.
|
||||
|
||||
Args:
|
||||
obj: The dictionary-shaped message to convert. Expected to contain
|
||||
fields like "role", "type", "content", "text", etc.
|
||||
|
||||
Returns:
|
||||
A normalized dictionary containing:
|
||||
- "role": The message role (user, assistant, system, tool)
|
||||
- "content": The flattened message content as plain text
|
||||
- "tool_calls": Tool calls if present (optional)
|
||||
- "name": Message name if present (optional)
|
||||
|
||||
Examples:
|
||||
>>> obj = {"type": "human", "content": "Hello"}
|
||||
>>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
|
||||
>>> result["role"]
|
||||
'user'
|
||||
>>> result["content"]
|
||||
'Hello'
|
||||
"""
|
||||
role: Optional[str] = obj.get("role")
|
||||
if not role:
|
||||
# Handle alternative type field mappings
|
||||
typ = obj.get("type")
|
||||
if typ in ("human", "user"):
|
||||
role = "user"
|
||||
elif typ in ("ai", "assistant"):
|
||||
role = "assistant"
|
||||
elif typ in ("system",):
|
||||
role = "system"
|
||||
elif typ in ("tool", "function"):
|
||||
role = "tool"
|
||||
|
||||
content = obj.get("content")
|
||||
if content is None and "text" in obj:
|
||||
content = obj["text"]
|
||||
|
||||
out: Dict[str, Any] = {
|
||||
"role": role or "assistant",
|
||||
"content": ChatStreamingHelper.flatten_content(content=content),
|
||||
}
|
||||
if "tool_calls" in obj:
|
||||
out["tool_calls"] = obj["tool_calls"]
|
||||
if obj.get("name"):
|
||||
out["name"] = obj["name"]
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
|
||||
"""Extract messages from LangGraph output objects.
|
||||
|
||||
This method handles various output formats from LangGraph execution,
|
||||
extracting the messages list from different possible structures.
|
||||
|
||||
Args:
|
||||
output_obj: The output object from LangGraph execution. Can be:
|
||||
- An object with a "messages" attribute
|
||||
- A dictionary with a "messages" key
|
||||
- Any other object (returns empty list)
|
||||
|
||||
Returns:
|
||||
A list of extracted messages, or an empty list if no messages
|
||||
are found or if the output object is None.
|
||||
|
||||
Examples:
|
||||
>>> output = {"messages": [{"role": "user", "content": "Hello"}]}
|
||||
>>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
|
||||
>>> len(messages)
|
||||
1
|
||||
"""
|
||||
if output_obj is None:
|
||||
return []
|
||||
|
||||
# Try to parse dicts first
|
||||
if isinstance(output_obj, dict):
|
||||
msgs = output_obj.get("messages")
|
||||
return msgs if isinstance(msgs, list) else []
|
||||
|
||||
# Then try to get messages attribute
|
||||
msgs = getattr(output_obj, "messages", None)
|
||||
return msgs if isinstance(msgs, list) else []
|
||||
1022
modules/features/chatBot/service.py
Normal file
1022
modules/features/chatBot/service.py
Normal file
File diff suppressed because it is too large
Load diff
106
modules/features/chatBot/utils/checkpointer.py
Normal file
106
modules/features/chatBot/utils/checkpointer.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"""PostgreSQL checkpointer utilities for LangGraph memory."""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
# Fix for Windows asyncio compatibility with psycopg (backup in case app.py fix didn't apply)
|
||||
if sys.platform == 'win32':
|
||||
try:
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
except RuntimeError:
|
||||
pass # Already set
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.rows import dict_row
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global checkpointer instance
|
||||
_checkpointer_instance: Optional[AsyncPostgresSaver] = None
|
||||
_connection_pool: Optional[AsyncConnectionPool] = None
|
||||
|
||||
|
||||
async def initialize_checkpointer() -> None:
|
||||
"""Initialize the PostgreSQL checkpointer for LangGraph.
|
||||
|
||||
This should be called during application startup.
|
||||
Creates a connection pool and PostgresSaver instance.
|
||||
"""
|
||||
global _checkpointer_instance, _connection_pool
|
||||
|
||||
if _checkpointer_instance is not None:
|
||||
logger.info("Checkpointer already initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
# Get database configuration from environment
|
||||
host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost")
|
||||
database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat")
|
||||
user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev")
|
||||
password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET")
|
||||
port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432")
|
||||
|
||||
# Build connection string
|
||||
connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}"
|
||||
|
||||
# Create async connection pool
|
||||
_connection_pool = AsyncConnectionPool(
|
||||
conninfo=connection_string,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
kwargs={"autocommit": True, "row_factory": dict_row},
|
||||
)
|
||||
|
||||
# Initialize the connection pool
|
||||
await _connection_pool.open()
|
||||
|
||||
# Create AsyncPostgresSaver with the pool
|
||||
_checkpointer_instance = AsyncPostgresSaver(_connection_pool)
|
||||
|
||||
# Setup the checkpointer (creates tables if needed)
|
||||
await _checkpointer_instance.setup()
|
||||
|
||||
logger.info("PostgreSQL checkpointer initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def close_checkpointer() -> None:
|
||||
"""Close the checkpointer and connection pool.
|
||||
|
||||
This should be called during application shutdown.
|
||||
"""
|
||||
global _checkpointer_instance, _connection_pool
|
||||
|
||||
if _connection_pool is not None:
|
||||
try:
|
||||
await _connection_pool.close()
|
||||
logger.info("PostgreSQL checkpointer connection pool closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing checkpointer connection pool: {str(e)}")
|
||||
|
||||
_checkpointer_instance = None
|
||||
_connection_pool = None
|
||||
|
||||
|
||||
def get_checkpointer() -> AsyncPostgresSaver:
|
||||
"""Get the global PostgreSQL checkpointer instance.
|
||||
|
||||
Returns:
|
||||
The initialized AsyncPostgresSaver instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If checkpointer is not initialized
|
||||
"""
|
||||
if _checkpointer_instance is None:
|
||||
raise RuntimeError(
|
||||
"PostgreSQL checkpointer not initialized. "
|
||||
"Call initialize_checkpointer() during application startup."
|
||||
)
|
||||
return _checkpointer_instance
|
||||
39
modules/features/chatBot/utils/permissions.py
Normal file
39
modules/features/chatBot/utils/permissions.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""Mock permissions module for chatbot access control.
|
||||
|
||||
This module provides mock permission functions that will be replaced
|
||||
with actual database-driven permissions in the future.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
|
||||
|
||||
# TODO: Replace these mock implementations with actual database queries
|
||||
|
||||
|
||||
def get_chatbot_tools(*, user_id: str) -> list[str]:
|
||||
"""Get list of tool IDs that the chatbot can use for a given user."""
|
||||
registry = get_registry()
|
||||
return registry.list_tool_ids()
|
||||
|
||||
|
||||
def get_chatbot_model(*, user_id: str) -> str:
|
||||
"""Gets the chatbot model(s) a user is allowed to use."""
|
||||
return "claude_4_5"
|
||||
|
||||
|
||||
def get_system_prompt(*, user_id: str) -> str:
|
||||
"""Get the system prompt for a user's chatbot session.
|
||||
|
||||
This is a mock implementation that returns a generic prompt with today's date.
|
||||
In production, this will query the database for user-specific or role-specific prompts.
|
||||
|
||||
Args:
|
||||
user_id: The unique identifier of the user
|
||||
|
||||
Returns:
|
||||
The system prompt string with the current date
|
||||
"""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
return f"You're a smart assistant. Today is {current_date}"
|
||||
305
modules/features/chatBot/utils/toolRegistry.py
Normal file
305
modules/features/chatBot/utils/toolRegistry.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
"""Tool registry for auto-discovering and managing chatbot tools.
|
||||
|
||||
This module provides a central registry that automatically discovers all tools
|
||||
in the chatbotTools directory structure and provides methods to query them.
|
||||
The registry is built in-memory at startup and does not require a database.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetadata:
|
||||
"""Metadata about a discovered chatbot tool.
|
||||
|
||||
Attributes:
|
||||
tool_id: Unique identifier (e.g., 'shared.tavily_search')
|
||||
name: Function name of the tool
|
||||
category: Category of the tool ('shared' or 'customer')
|
||||
description: Tool description from docstring
|
||||
tool_instance: The actual LangChain tool instance
|
||||
module_path: Full Python module path
|
||||
"""
|
||||
|
||||
tool_id: str
|
||||
name: str
|
||||
category: str
|
||||
description: str
|
||||
tool_instance: BaseTool
|
||||
module_path: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a pretty-printed string representation for logging."""
|
||||
return (
|
||||
f"ToolMetadata(\n"
|
||||
f" tool_id='{self.tool_id}',\n"
|
||||
f" name='{self.name}',\n"
|
||||
f" category='{self.category}',\n"
|
||||
f" description='{self.description}',\n"
|
||||
f" module_path='{self.module_path}'\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Central registry for all chatbot tools.
|
||||
|
||||
This class discovers and catalogs all tools decorated with @tool in the
|
||||
chatbotTools directory structure. Tools are automatically discovered at
|
||||
initialization by scanning the filesystem.
|
||||
|
||||
The registry provides methods to query tools by ID, category, or get all tools.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an empty tool registry."""
|
||||
self._tools: Dict[str, ToolMetadata] = {}
|
||||
self._initialized: bool = False
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Discover and register all tools from the chatbotTools directory.
|
||||
|
||||
This method scans both sharedTools and customerTools directories,
|
||||
imports all tool*.py modules, and extracts functions decorated with @tool.
|
||||
|
||||
This method is idempotent - calling it multiple times has no effect
|
||||
after the first initialization.
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.debug("Tool registry already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.info("Initializing tool registry...")
|
||||
|
||||
# Get base path to chatbotTools directory
|
||||
base_path = Path(__file__).parent.parent / "chatbotTools"
|
||||
|
||||
if not base_path.exists():
|
||||
logger.warning(f"chatbotTools directory not found at {base_path}")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
# Discover tools in each category
|
||||
self._discover_category(
|
||||
category_path=base_path / "sharedTools", category="shared"
|
||||
)
|
||||
self._discover_category(
|
||||
category_path=base_path / "customerTools", category="customer"
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"Tool registry initialized with {len(self._tools)} tools")
|
||||
|
||||
def _discover_category(self, *, category_path: Path, category: str) -> None:
|
||||
"""Discover all tools in a specific category directory.
|
||||
|
||||
Args:
|
||||
category_path: Path to the category directory (sharedTools or customerTools)
|
||||
category: Category name ('shared' or 'customer')
|
||||
"""
|
||||
if not category_path.exists():
|
||||
logger.warning(f"Category directory not found: {category_path}")
|
||||
return
|
||||
|
||||
logger.debug(f"Discovering tools in category: {category}")
|
||||
|
||||
# Find all tool*.py files (excluding __init__.py)
|
||||
tool_files = [
|
||||
f for f in category_path.glob("tool*.py") if f.name != "__init__.py"
|
||||
]
|
||||
|
||||
for tool_file in tool_files:
|
||||
self._import_and_register_tools(
|
||||
tool_file=tool_file, category=category, category_path=category_path
|
||||
)
|
||||
|
||||
logger.debug(f"Discovered {len(tool_files)} tool files in {category}")
|
||||
|
||||
def _import_and_register_tools(
|
||||
self, *, tool_file: Path, category: str, category_path: Path
|
||||
) -> None:
|
||||
"""Import a tool module and register all discovered tools.
|
||||
|
||||
Args:
|
||||
tool_file: Path to the tool Python file
|
||||
category: Category name ('shared' or 'customer')
|
||||
category_path: Path to the category directory
|
||||
"""
|
||||
# Construct module name
|
||||
module_name = (
|
||||
f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Find all BaseTool instances in the module
|
||||
tools_found = 0
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if isinstance(obj, BaseTool):
|
||||
self._register_tool(
|
||||
tool_instance=obj,
|
||||
name=name,
|
||||
category=category,
|
||||
module_path=module_name,
|
||||
)
|
||||
tools_found += 1
|
||||
|
||||
if tools_found == 0:
|
||||
logger.warning(f"No tools found in {module_name}")
|
||||
else:
|
||||
logger.debug(f"Loaded {tools_found} tool(s) from {module_name}")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"Import error loading tools from {module_name}: {str(e)}. "
|
||||
f"This tool will not be available."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}"
|
||||
)
|
||||
|
||||
def _register_tool(
|
||||
self, *, tool_instance: BaseTool, name: str, category: str, module_path: str
|
||||
) -> None:
|
||||
"""Register a single tool in the registry.
|
||||
|
||||
Args:
|
||||
tool_instance: The LangChain tool instance
|
||||
name: Function name of the tool
|
||||
category: Category name ('shared' or 'customer')
|
||||
module_path: Full Python module path
|
||||
"""
|
||||
tool_id = f"{category}.{name}"
|
||||
|
||||
# Check for duplicate tool IDs
|
||||
if tool_id in self._tools:
|
||||
logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting")
|
||||
|
||||
metadata = ToolMetadata(
|
||||
tool_id=tool_id,
|
||||
name=name,
|
||||
category=category,
|
||||
description=tool_instance.description or "",
|
||||
tool_instance=tool_instance,
|
||||
module_path=module_path,
|
||||
)
|
||||
|
||||
self._tools[tool_id] = metadata
|
||||
logger.debug(f"Registered tool: {tool_id}")
|
||||
|
||||
def get_all_tools(self) -> List[ToolMetadata]:
|
||||
"""Get all registered tools.
|
||||
|
||||
Returns:
|
||||
List of all tool metadata objects
|
||||
"""
|
||||
return list(self._tools.values())
|
||||
|
||||
def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]:
|
||||
"""Get a specific tool by its ID.
|
||||
|
||||
Args:
|
||||
tool_id: The tool identifier (e.g., 'shared.tavily_search')
|
||||
|
||||
Returns:
|
||||
Tool metadata if found, None otherwise
|
||||
"""
|
||||
return self._tools.get(tool_id)
|
||||
|
||||
def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]:
|
||||
"""Get all tools in a specific category.
|
||||
|
||||
Args:
|
||||
category: Category name ('shared' or 'customer')
|
||||
|
||||
Returns:
|
||||
List of tool metadata for the specified category
|
||||
"""
|
||||
return [t for t in self._tools.values() if t.category == category]
|
||||
|
||||
def list_tool_ids(self) -> List[str]:
|
||||
"""Get a list of all registered tool IDs.
|
||||
|
||||
Returns:
|
||||
List of tool ID strings
|
||||
"""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]:
|
||||
"""Get actual tool instances for a list of tool IDs.
|
||||
|
||||
This is useful for filtering tools based on user permissions.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs to retrieve
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances for the specified IDs
|
||||
"""
|
||||
instances = []
|
||||
for tool_id in tool_ids:
|
||||
metadata = self.get_tool(tool_id=tool_id)
|
||||
if metadata:
|
||||
instances.append(metadata.tool_instance)
|
||||
else:
|
||||
logger.warning(f"Tool ID not found in registry: {tool_id}")
|
||||
return instances
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the registry has been initialized.
|
||||
|
||||
Returns:
|
||||
True if initialized, False otherwise
|
||||
"""
|
||||
return self._initialized
|
||||
|
||||
|
||||
# Global registry instance
|
||||
_registry: Optional[ToolRegistry] = None
|
||||
|
||||
|
||||
def get_registry() -> ToolRegistry:
|
||||
"""Get the global tool registry instance.
|
||||
|
||||
This function ensures the registry is initialized on first access.
|
||||
Subsequent calls return the same instance.
|
||||
|
||||
Returns:
|
||||
The global ToolRegistry instance
|
||||
"""
|
||||
global _registry
|
||||
|
||||
if _registry is None:
|
||||
_registry = ToolRegistry()
|
||||
|
||||
if not _registry.is_initialized:
|
||||
_registry.initialize()
|
||||
|
||||
return _registry
|
||||
|
||||
|
||||
def reinitialize_registry() -> ToolRegistry:
|
||||
"""Force reinitialize the tool registry.
|
||||
|
||||
This is useful for testing or when tools are added dynamically.
|
||||
|
||||
Returns:
|
||||
The reinitialized ToolRegistry instance
|
||||
"""
|
||||
global _registry
|
||||
_registry = ToolRegistry()
|
||||
_registry.initialize()
|
||||
return _registry
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -21,13 +21,14 @@ os.makedirs(staticFolder, exist_ok=True)
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="",
|
||||
tags=["Administration"],
|
||||
responses={404: {"description": "Not found"}}
|
||||
prefix="", tags=["Administration"], responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
# Mount static files
|
||||
router.mount("/static", StaticFiles(directory=str(staticFolder), html=True), name="static")
|
||||
router.mount(
|
||||
"/static", StaticFiles(directory=str(staticFolder), html=True), name="static"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@limiter.limit("30/minute")
|
||||
|
|
@ -36,31 +37,40 @@ async def root(request: Request) -> Dict[str, str]:
|
|||
# Validate required configuration values
|
||||
allowedOrigins = APP_CONFIG.get("APP_ALLOWED_ORIGINS")
|
||||
if not allowedOrigins:
|
||||
raise HTTPException(status_code=500, detail="APP_ALLOWED_ORIGINS configuration is required")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500, detail="APP_ALLOWED_ORIGINS configuration is required"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "online",
|
||||
"message": "Data Platform API is active",
|
||||
"allowedOrigins": f"Allowed origins are {allowedOrigins}"
|
||||
"allowedOrigins": f"Allowed origins are {allowedOrigins}",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/environment")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_environment(request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, str]:
|
||||
async def get_environment(
|
||||
request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
) -> Dict[str, str]:
|
||||
"""Get environment configuration for frontend"""
|
||||
# Validate required configuration values
|
||||
apiBaseUrl = APP_CONFIG.get("APP_API_URL")
|
||||
if not apiBaseUrl:
|
||||
raise HTTPException(status_code=500, detail="APP_API_URL configuration is required")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500, detail="APP_API_URL configuration is required"
|
||||
)
|
||||
|
||||
environment = APP_CONFIG.get("APP_ENV")
|
||||
if not environment:
|
||||
raise HTTPException(status_code=500, detail="APP_ENV configuration is required")
|
||||
|
||||
|
||||
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
|
||||
if not instanceLabel:
|
||||
raise HTTPException(status_code=500, detail="APP_ENV_LABEL configuration is required")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500, detail="APP_ENV_LABEL configuration is required"
|
||||
)
|
||||
|
||||
return {
|
||||
"apiBaseUrl": apiBaseUrl,
|
||||
"environment": environment,
|
||||
|
|
@ -68,13 +78,17 @@ async def get_environment(request: Request, currentUser: Dict[str, Any] = Depend
|
|||
# Add other environment variables the frontend might need
|
||||
}
|
||||
|
||||
|
||||
@router.options("/{fullPath:path}")
|
||||
@limiter.limit("60/minute")
|
||||
async def options_route(request: Request, fullPath: str) -> Response:
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/favicon.ico")
|
||||
@limiter.limit("30/minute")
|
||||
async def favicon(request: Request) -> FileResponse:
|
||||
return FileResponse(str(staticFolder / "favicon.ico"), media_type="image/x-icon")
|
||||
|
||||
favicon_path = staticFolder / "favicon.ico"
|
||||
if not favicon_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Favicon not found")
|
||||
return FileResponse(str(favicon_path), media_type="image/x-icon")
|
||||
|
|
|
|||
655
modules/routes/routeChatbot.py
Normal file
655
modules/routes/routeChatbot.py
Normal file
|
|
@ -0,0 +1,655 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
from modules.features.chatBot.database import get_async_db_session
|
||||
from modules.features.chatBot.service import (
|
||||
get_or_create_thread_for_user,
|
||||
)
|
||||
from modules.datamodels.datamodelUam import User, UserPrivilege
|
||||
from modules.datamodels.datamodelChatbot import (
|
||||
ChatMessageRequest,
|
||||
MessageItem,
|
||||
ChatMessageResponse,
|
||||
ThreadSummary,
|
||||
ThreadListResponse,
|
||||
ThreadDetail,
|
||||
RenameThreadRequest,
|
||||
DeleteResponse,
|
||||
ToolListResponse,
|
||||
ToolInfo,
|
||||
GrantToolRequest,
|
||||
GrantToolResponse,
|
||||
RevokeToolRequest,
|
||||
RevokeToolResponse,
|
||||
UpdateToolRequest,
|
||||
UpdateToolResponse,
|
||||
)
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
from modules.features.chatBot import service as chat_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/chatbot",
|
||||
tags=["Chatbot"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
# --- Actual endpoints for chatbot ---
|
||||
|
||||
|
||||
@router.post("/message/stream")
|
||||
@limiter.limit("30/minute")
|
||||
async def post_chat_message_stream(
|
||||
*,
|
||||
request: Request,
|
||||
message_request: ChatMessageRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Post a message to a chat thread with streaming progress updates.
|
||||
Creates a new thread if thread_id is not provided.
|
||||
|
||||
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
||||
"""
|
||||
try:
|
||||
# Validate and get tools for the request
|
||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||
user_id=currentUser.id,
|
||||
requested_tool_ids=message_request.tools,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Get or create thread using helper function
|
||||
thread_id = await get_or_create_thread_for_user(
|
||||
thread_id=message_request.thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
thread_name=message_request.message[:100],
|
||||
refresh_date_updated=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} posted streaming message to thread {thread_id}"
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
chat_service.post_message_stream(
|
||||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
tool_ids=tool_ids,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/message", response_model=ChatMessageResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def post_chat_message(
|
||||
*,
|
||||
request: Request,
|
||||
message_request: ChatMessageRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ChatMessageResponse:
|
||||
"""
|
||||
Post a message to a chat thread and get assistant response (non-streaming).
|
||||
Creates a new thread if thread_id is not provided.
|
||||
|
||||
For streaming updates, use the /message/stream endpoint instead.
|
||||
"""
|
||||
try:
|
||||
# Validate and get tools for the request
|
||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||
user_id=currentUser.id,
|
||||
requested_tool_ids=message_request.tools,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Get or create thread using helper function
|
||||
thread_id = await get_or_create_thread_for_user(
|
||||
thread_id=message_request.thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
thread_name=message_request.message[:100],
|
||||
refresh_date_updated=True,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
||||
|
||||
response = await chat_service.post_message(
|
||||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
tool_ids=tool_ids,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads", response_model=ThreadListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_all_threads(
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ThreadListResponse:
|
||||
"""
|
||||
Get all chat threads for the current user.
|
||||
"""
|
||||
try:
|
||||
# Get all threads for the current user
|
||||
threads = await chat_service.get_all_threads_for_user(
|
||||
user=currentUser, session=session
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved {len(threads)} threads")
|
||||
|
||||
return ThreadListResponse(threads=threads)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving threads: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve threads: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}", response_model=ThreadDetail)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_thread_by_id(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ThreadDetail:
|
||||
"""
|
||||
Get a specific chat thread with all its messages from LangGraph checkpointer.
|
||||
"""
|
||||
try:
|
||||
thread_detail = await chat_service.get_thread_detail_for_user(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
|
||||
return thread_detail
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found",
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/threads/{thread_id}", response_model=DeleteResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def rename_thread(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
rename_request: RenameThreadRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> DeleteResponse:
|
||||
"""
|
||||
Rename a chat thread.
|
||||
"""
|
||||
try:
|
||||
await chat_service.update_thread_name(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
new_thread_name=rename_request.new_name,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} renamed thread {thread_id} to '{rename_request.new_name}'"
|
||||
)
|
||||
|
||||
return DeleteResponse(
|
||||
message=f"Thread {thread_id} successfully renamed",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found or permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error renaming thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to rename thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/threads/{thread_id}", response_model=DeleteResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_thread(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> DeleteResponse:
|
||||
"""
|
||||
Delete a chat thread and all its associated data from both LangGraph and database.
|
||||
"""
|
||||
try:
|
||||
await chat_service.delete_thread_for_user(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} deleted thread {thread_id}")
|
||||
|
||||
return DeleteResponse(
|
||||
message=f"Thread {thread_id} successfully deleted",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found or permission denied",
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error deleting thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
# Tool Management Endpoints
|
||||
|
||||
|
||||
@router.get("/tools", response_model=ToolListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_all_tools(
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ToolListResponse:
|
||||
"""
|
||||
Get all available chatbot tools.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can view tools",
|
||||
)
|
||||
|
||||
# Get all tools from service
|
||||
tools_data = await chat_service.get_all_tools(session=session)
|
||||
|
||||
# Convert to ToolInfo objects
|
||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved {len(tools)} tools")
|
||||
|
||||
return ToolListResponse(tools=tools)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving tools: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tools/grant", response_model=GrantToolResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def grant_tool_to_user(
|
||||
*,
|
||||
request: Request,
|
||||
grant_request: GrantToolRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> GrantToolResponse:
|
||||
"""
|
||||
Grant a tool to a user.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can grant tools",
|
||||
)
|
||||
|
||||
# Grant the tool
|
||||
await chat_service.grant_tool_to_user(
|
||||
user_id=grant_request.user_id,
|
||||
tool_id=grant_request.tool_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} granted tool {grant_request.tool_id} to user {grant_request.user_id}"
|
||||
)
|
||||
|
||||
return GrantToolResponse(
|
||||
message=f"Tool successfully granted to user {grant_request.user_id}",
|
||||
user_id=grant_request.user_id,
|
||||
tool_id=grant_request.tool_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e) or "Invalid request",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error granting tool: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to grant tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/tools/revoke", response_model=RevokeToolResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def revoke_tool_from_user(
|
||||
*,
|
||||
request: Request,
|
||||
revoke_request: RevokeToolRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> RevokeToolResponse:
|
||||
"""
|
||||
Revoke a tool from a user.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can revoke tools",
|
||||
)
|
||||
|
||||
# Revoke the tool
|
||||
await chat_service.revoke_tool_from_user(
|
||||
user_id=revoke_request.user_id,
|
||||
tool_id=revoke_request.tool_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} revoked tool {revoke_request.tool_id} from user {revoke_request.user_id}"
|
||||
)
|
||||
|
||||
return RevokeToolResponse(
|
||||
message=f"Tool successfully revoked from user {revoke_request.user_id}",
|
||||
user_id=revoke_request.user_id,
|
||||
tool_id=revoke_request.tool_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e) or "Invalid request",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error revoking tool: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to revoke tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/tools/{tool_id}", response_model=UpdateToolResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_tool(
|
||||
*,
|
||||
request: Request,
|
||||
tool_id: str,
|
||||
update_request: UpdateToolRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> UpdateToolResponse:
|
||||
"""
|
||||
Update a tool's label and/or description.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can update tools",
|
||||
)
|
||||
|
||||
# Update the tool
|
||||
updated_fields = await chat_service.update_tool(
|
||||
tool_id=tool_id,
|
||||
label=update_request.label,
|
||||
description=update_request.description,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} updated tool {tool_id}, fields: {updated_fields}"
|
||||
)
|
||||
|
||||
return UpdateToolResponse(
|
||||
message="Tool successfully updated",
|
||||
tool_id=tool_id,
|
||||
updated_fields=updated_fields,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e) or "Invalid request",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating tool: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to update tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools/user/{user_id}", response_model=ToolListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_tools_for_specific_user(
|
||||
*,
|
||||
request: Request,
|
||||
user_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ToolListResponse:
|
||||
"""
|
||||
Get all tools granted to a specific user.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can view user tools",
|
||||
)
|
||||
|
||||
# Get tools for the specified user
|
||||
tools_data = await chat_service.get_tools_for_user(
|
||||
user_id=user_id, session=session
|
||||
)
|
||||
|
||||
# Convert to ToolInfo objects
|
||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} retrieved {len(tools)} tools for user {user_id}"
|
||||
)
|
||||
|
||||
return ToolListResponse(tools=tools)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving tools for user {user_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve tools for user: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools/me", response_model=ToolListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_my_tools(
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ToolListResponse:
|
||||
"""
|
||||
Get all tools the current user has access to.
|
||||
"""
|
||||
try:
|
||||
# Get tools for the current user
|
||||
tools_data = await chat_service.get_tools_for_user(
|
||||
user_id=currentUser.id, session=session
|
||||
)
|
||||
|
||||
# Convert to ToolInfo objects
|
||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} retrieved {len(tools)} tools for themselves"
|
||||
)
|
||||
|
||||
return ToolListResponse(tools=tools)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving tools for user {currentUser.id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve your tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
|
@ -2,52 +2,55 @@
|
|||
Shared utilities for model attributes and labels.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Dict, Any, List, Type, Optional, Union
|
||||
import inspect
|
||||
import importlib
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ModelMixin:
|
||||
"""Mixin class that provides serialization methods for Pydantic models."""
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
Handles both Pydantic v1 and v2.
|
||||
All timestamp fields remain as float values.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary representation of the model
|
||||
"""
|
||||
# Get the raw dictionary
|
||||
if hasattr(self, 'model_dump'):
|
||||
if hasattr(self, "model_dump"):
|
||||
data: Dict[str, Any] = self.model_dump() # Pydantic v2
|
||||
else:
|
||||
data: Dict[str, Any] = self.dict() # Pydantic v1
|
||||
|
||||
|
||||
# All fields (including timestamps) remain in their original format
|
||||
# No conversions needed - timestamps are already float
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ModelMixin':
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ModelMixin":
|
||||
"""
|
||||
Create a Pydantic model instance from a dictionary.
|
||||
|
||||
|
||||
Args:
|
||||
data: Dictionary containing the model data
|
||||
|
||||
|
||||
Returns:
|
||||
ModelMixin: New instance of the model class
|
||||
"""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
# Define the AttributeDefinition class here instead of importing it
|
||||
class AttributeDefinition(BaseModel, ModelMixin):
|
||||
"""Definition of a model attribute with its metadata."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
label: str
|
||||
|
|
@ -64,41 +67,47 @@ class AttributeDefinition(BaseModel, ModelMixin):
|
|||
order: int = 0
|
||||
placeholder: Optional[str] = None
|
||||
|
||||
|
||||
# Global registry for model labels
|
||||
MODEL_LABELS: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
|
||||
|
||||
def to_dict(model: BaseModel) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
Handles both Pydantic v1 and v2.
|
||||
|
||||
|
||||
Args:
|
||||
model: The Pydantic model instance to convert
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary representation of the model
|
||||
"""
|
||||
if hasattr(model, 'model_dump'):
|
||||
if hasattr(model, "model_dump"):
|
||||
return model.model_dump() # Pydantic v2
|
||||
return model.dict() # Pydantic v1
|
||||
|
||||
|
||||
def from_dict(model_class: Type[BaseModel], data: Dict[str, Any]) -> BaseModel:
|
||||
"""
|
||||
Create a Pydantic model instance from a dictionary.
|
||||
|
||||
|
||||
Args:
|
||||
model_class: The Pydantic model class to instantiate
|
||||
data: Dictionary containing the model data
|
||||
|
||||
|
||||
Returns:
|
||||
BaseModel: New instance of the model class
|
||||
"""
|
||||
return model_class(**data)
|
||||
|
||||
def register_model_labels(model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]]):
|
||||
|
||||
def register_model_labels(
|
||||
model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]]
|
||||
):
|
||||
"""
|
||||
Register labels for a model's attributes and the model itself.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the model class
|
||||
model_label: Dictionary mapping language codes to model labels
|
||||
|
|
@ -106,38 +115,37 @@ def register_model_labels(model_name: str, model_label: Dict[str, str], labels:
|
|||
labels: Dictionary mapping attribute names to their translations
|
||||
e.g. {"name": {"en": "Name", "fr": "Nom"}}
|
||||
"""
|
||||
MODEL_LABELS[model_name] = {
|
||||
"model": model_label,
|
||||
"attributes": labels
|
||||
}
|
||||
MODEL_LABELS[model_name] = {"model": model_label, "attributes": labels}
|
||||
|
||||
|
||||
def get_model_labels(model_name: str, language: str = "en") -> Dict[str, str]:
|
||||
"""
|
||||
Get labels for a model's attributes in the specified language.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the model class
|
||||
language: Language code (default: "en")
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary mapping attribute names to their labels in the specified language
|
||||
"""
|
||||
model_data = MODEL_LABELS.get(model_name, {})
|
||||
attribute_labels = model_data.get("attributes", {})
|
||||
|
||||
|
||||
return {
|
||||
attr: translations.get(language, translations.get("en", attr))
|
||||
for attr, translations in attribute_labels.items()
|
||||
}
|
||||
|
||||
|
||||
def get_model_label(model_name: str, language: str = "en") -> str:
|
||||
"""
|
||||
Get the label for a model in the specified language.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the model class
|
||||
language: Language code (default: "en")
|
||||
|
||||
|
||||
Returns:
|
||||
Model label in the specified language, or model name if no label exists
|
||||
"""
|
||||
|
|
@ -145,173 +153,228 @@ def get_model_label(model_name: str, language: str = "en") -> str:
|
|||
model_label = model_data.get("model", {})
|
||||
return model_label.get(language, model_label.get("en", model_name))
|
||||
|
||||
def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguage: str = "en") -> Dict[str, Any]:
|
||||
|
||||
def getModelAttributeDefinitions(
|
||||
modelClass: Type[BaseModel] = None, userLanguage: str = "en"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get attribute definitions for a model class.
|
||||
|
||||
|
||||
Args:
|
||||
modelClass: The model class to get attributes for
|
||||
userLanguage: Language code for translations (default: "en")
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary containing model label and attribute definitions
|
||||
"""
|
||||
if not modelClass:
|
||||
return {}
|
||||
|
||||
|
||||
attributes = []
|
||||
model_name = modelClass.__name__
|
||||
labels = get_model_labels(model_name, userLanguage)
|
||||
model_label = get_model_label(model_name, userLanguage)
|
||||
|
||||
|
||||
# Handle both Pydantic v1 and v2
|
||||
if hasattr(modelClass, 'model_fields'): # Pydantic v2
|
||||
if hasattr(modelClass, "model_fields"): # Pydantic v2
|
||||
fields = modelClass.model_fields
|
||||
for name, field in fields.items():
|
||||
# Extract frontend metadata from field info
|
||||
field_info = field.field_info if hasattr(field, 'field_info') else None
|
||||
field_info = field.field_info if hasattr(field, "field_info") else None
|
||||
# Check both direct attributes and extra field for frontend metadata
|
||||
frontend_type = None
|
||||
frontend_readonly = False
|
||||
frontend_required = field.is_required()
|
||||
frontend_options = None
|
||||
|
||||
|
||||
if field_info:
|
||||
# Try direct attributes first
|
||||
frontend_type = getattr(field_info, 'frontend_type', None)
|
||||
frontend_readonly = getattr(field_info, 'frontend_readonly', False)
|
||||
frontend_required = getattr(field_info, 'frontend_required', frontend_required)
|
||||
frontend_options = getattr(field_info, 'frontend_options', None)
|
||||
|
||||
frontend_type = getattr(field_info, "frontend_type", None)
|
||||
frontend_readonly = getattr(field_info, "frontend_readonly", False)
|
||||
frontend_required = getattr(
|
||||
field_info, "frontend_required", frontend_required
|
||||
)
|
||||
frontend_options = getattr(field_info, "frontend_options", None)
|
||||
|
||||
# If not found, check extra field
|
||||
if hasattr(field_info, 'extra') and field_info.extra:
|
||||
if hasattr(field_info, "extra") and field_info.extra:
|
||||
if frontend_type is None:
|
||||
frontend_type = field_info.extra.get('frontend_type')
|
||||
frontend_type = field_info.extra.get("frontend_type")
|
||||
if not frontend_readonly:
|
||||
frontend_readonly = field_info.extra.get('frontend_readonly', False)
|
||||
if frontend_required == field.is_required(): # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get('frontend_required', frontend_required)
|
||||
frontend_readonly = field_info.extra.get(
|
||||
"frontend_readonly", False
|
||||
)
|
||||
if (
|
||||
frontend_required == field.is_required()
|
||||
): # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get(
|
||||
"frontend_required", frontend_required
|
||||
)
|
||||
if frontend_options is None:
|
||||
frontend_options = field_info.extra.get('frontend_options')
|
||||
|
||||
frontend_options = field_info.extra.get("frontend_options")
|
||||
|
||||
# Use frontend type if available, otherwise fall back to Python type
|
||||
field_type = frontend_type if frontend_type else (field.annotation.__name__ if hasattr(field.annotation, "__name__") else str(field.annotation))
|
||||
|
||||
attributes.append({
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.description if hasattr(field, "description") else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options
|
||||
})
|
||||
field_type = (
|
||||
frontend_type
|
||||
if frontend_type
|
||||
else (
|
||||
field.annotation.__name__
|
||||
if hasattr(field.annotation, "__name__")
|
||||
else str(field.annotation)
|
||||
)
|
||||
)
|
||||
|
||||
attributes.append(
|
||||
{
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.description
|
||||
if hasattr(field, "description")
|
||||
else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options,
|
||||
}
|
||||
)
|
||||
else: # Pydantic v1
|
||||
fields = modelClass.__fields__
|
||||
for name, field in fields.items():
|
||||
# Extract frontend metadata from field info
|
||||
field_info = field.field_info if hasattr(field, 'field_info') else None
|
||||
field_info = field.field_info if hasattr(field, "field_info") else None
|
||||
# Check both direct attributes and extra field for frontend metadata
|
||||
frontend_type = None
|
||||
frontend_readonly = False
|
||||
frontend_required = field.required
|
||||
frontend_options = None
|
||||
|
||||
|
||||
if field_info:
|
||||
# Try direct attributes first
|
||||
frontend_type = getattr(field_info, 'frontend_type', None)
|
||||
frontend_readonly = getattr(field_info, 'frontend_readonly', False)
|
||||
frontend_required = getattr(field_info, 'frontend_required', frontend_required)
|
||||
frontend_options = getattr(field_info, 'frontend_options', None)
|
||||
|
||||
frontend_type = getattr(field_info, "frontend_type", None)
|
||||
frontend_readonly = getattr(field_info, "frontend_readonly", False)
|
||||
frontend_required = getattr(
|
||||
field_info, "frontend_required", frontend_required
|
||||
)
|
||||
frontend_options = getattr(field_info, "frontend_options", None)
|
||||
|
||||
# If not found, check extra field
|
||||
if hasattr(field_info, 'extra') and field_info.extra:
|
||||
if hasattr(field_info, "extra") and field_info.extra:
|
||||
if frontend_type is None:
|
||||
frontend_type = field_info.extra.get('frontend_type')
|
||||
frontend_type = field_info.extra.get("frontend_type")
|
||||
if not frontend_readonly:
|
||||
frontend_readonly = field_info.extra.get('frontend_readonly', False)
|
||||
if frontend_required == field.required: # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get('frontend_required', frontend_required)
|
||||
frontend_readonly = field_info.extra.get(
|
||||
"frontend_readonly", False
|
||||
)
|
||||
if (
|
||||
frontend_required == field.required
|
||||
): # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get(
|
||||
"frontend_required", frontend_required
|
||||
)
|
||||
if frontend_options is None:
|
||||
frontend_options = field_info.extra.get('frontend_options')
|
||||
|
||||
frontend_options = field_info.extra.get("frontend_options")
|
||||
|
||||
# Use frontend type if available, otherwise fall back to Python type
|
||||
field_type = frontend_type if frontend_type else (field.type_.__name__ if hasattr(field.type_, "__name__") else str(field.type_))
|
||||
|
||||
attributes.append({
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.field_info.description if hasattr(field.field_info, "description") else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options
|
||||
})
|
||||
|
||||
return {
|
||||
"model": model_label,
|
||||
"attributes": attributes
|
||||
}
|
||||
field_type = (
|
||||
frontend_type
|
||||
if frontend_type
|
||||
else (
|
||||
field.type_.__name__
|
||||
if hasattr(field.type_, "__name__")
|
||||
else str(field.type_)
|
||||
)
|
||||
)
|
||||
|
||||
attributes.append(
|
||||
{
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.field_info.description
|
||||
if hasattr(field.field_info, "description")
|
||||
else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options,
|
||||
}
|
||||
)
|
||||
|
||||
return {"model": model_label, "attributes": attributes}
|
||||
|
||||
|
||||
def getModelClasses() -> Dict[str, Type[BaseModel]]:
|
||||
"""
|
||||
Dynamically get all model classes from all model modules.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Type[BaseModel]]: Dictionary of model class names to their classes
|
||||
"""
|
||||
modelClasses = {}
|
||||
|
||||
|
||||
# Get the interfaces directory path
|
||||
interfaces_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'interfaces')
|
||||
|
||||
interfaces_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "interfaces"
|
||||
)
|
||||
|
||||
# Find all model files in interfaces directory
|
||||
for fileName in os.listdir(interfaces_dir):
|
||||
if fileName.endswith('Model.py'):
|
||||
if fileName.endswith("Model.py"):
|
||||
# Convert fileName to module name (e.g., gatewayModel.py -> gatewayModel)
|
||||
module_name = fileName[:-3]
|
||||
|
||||
|
||||
# Import the module dynamically
|
||||
module = importlib.import_module(f'modules.interfaces.{module_name}')
|
||||
|
||||
module = importlib.import_module(f"modules.interfaces.{module_name}")
|
||||
|
||||
# Get all classes from the module
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj != BaseModel:
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, BaseModel)
|
||||
and obj != BaseModel
|
||||
):
|
||||
modelClasses[name] = obj
|
||||
|
||||
|
||||
# Also get models from datamodels directory
|
||||
datamodels_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'datamodels')
|
||||
|
||||
datamodels_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "datamodels"
|
||||
)
|
||||
|
||||
# Find all model files in datamodels directory
|
||||
for fileName in os.listdir(datamodels_dir):
|
||||
if fileName.startswith('datamodel') and fileName.endswith('.py'):
|
||||
if fileName.startswith("datamodel") and fileName.endswith(".py"):
|
||||
# Convert fileName to module name (e.g., datamodelUtils.py -> datamodelUtils)
|
||||
module_name = fileName[:-3]
|
||||
|
||||
|
||||
# Import the module dynamically
|
||||
module = importlib.import_module(f'modules.datamodels.{module_name}')
|
||||
|
||||
module = importlib.import_module(f"modules.datamodels.{module_name}")
|
||||
|
||||
# Get all classes from the module
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj != BaseModel:
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, BaseModel)
|
||||
and obj != BaseModel
|
||||
):
|
||||
modelClasses[name] = obj
|
||||
|
||||
|
||||
return modelClasses
|
||||
|
||||
|
||||
class AttributeResponse(BaseModel):
|
||||
"""Response model for entity attributes"""
|
||||
|
||||
attributes: List[AttributeDefinition]
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"attributes": [
|
||||
{
|
||||
|
|
@ -322,8 +385,9 @@ class AttributeResponse(BaseModel):
|
|||
"placeholder": "Please enter username",
|
||||
"editable": True,
|
||||
"visible": True,
|
||||
"order": 0
|
||||
"order": 0,
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
## Web Framework & API
|
||||
fastapi==0.104.1
|
||||
fastapi==0.115.0 # Upgraded for Pydantic v2 compatibility
|
||||
websockets==12.0
|
||||
uvicorn==0.23.2
|
||||
python-multipart==0.0.6
|
||||
httpx==0.25.0
|
||||
pydantic==1.10.13 # Ältere Version ohne Rust-Abhängigkeit
|
||||
httpx>=0.25.2
|
||||
pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility
|
||||
email-validator==2.0.0 # Required by Pydantic for email validation
|
||||
slowapi==0.1.8 # For rate limiting
|
||||
|
||||
|
|
@ -109,3 +109,14 @@ xyzservices>=2021.09.1
|
|||
|
||||
# PostgreSQL connector dependencies
|
||||
psycopg2-binary==2.9.9
|
||||
|
||||
## LangChain & LangGraph
|
||||
langchain==0.3.27
|
||||
langgraph==0.6.8
|
||||
langchain-core==0.3.77
|
||||
langchain-anthropic==0.3.1 # For Claude models
|
||||
psycopg[binary]==3.2.1 # For PostgreSQL async support (LangGraph checkpointer)
|
||||
psycopg-pool==3.2.1 # Connection pooling for PostgreSQL
|
||||
langgraph-checkpoint-postgres==2.0.24
|
||||
|
||||
greenlet==3.2.4
|
||||
198
tests/features/chatBot/utils/test_toolRegistry.py
Normal file
198
tests/features/chatBot/utils/test_toolRegistry.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
"""Pytest tests for the tool registry.
|
||||
|
||||
This module tests that the tool registry correctly discovers and catalogs
|
||||
all tools in the chatbotTools directory.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pytest
|
||||
from modules.features.chatBot.utils.toolRegistry import (
|
||||
ToolMetadata,
|
||||
ToolRegistry,
|
||||
get_registry,
|
||||
reinitialize_registry,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
"""Test suite for ToolRegistry class."""
|
||||
|
||||
@pytest.fixture
|
||||
def registry(self) -> ToolRegistry:
|
||||
"""Provide a fresh registry instance for each test."""
|
||||
return reinitialize_registry()
|
||||
|
||||
def test_registry_initialization(self, registry: ToolRegistry) -> None:
|
||||
"""Test that registry initializes correctly."""
|
||||
assert registry.is_initialized
|
||||
assert isinstance(registry._tools, dict)
|
||||
|
||||
def test_get_all_tools(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting all registered tools."""
|
||||
all_tools = registry.get_all_tools()
|
||||
assert isinstance(all_tools, list)
|
||||
assert len(all_tools) > 0
|
||||
assert all(isinstance(tool, ToolMetadata) for tool in all_tools)
|
||||
|
||||
# Log all discovered tools
|
||||
logger.info(f"Found {len(all_tools)} tools in registry:")
|
||||
for tool in all_tools:
|
||||
logger.info(f"\n{tool}")
|
||||
|
||||
def test_tool_metadata_structure(self, registry: ToolRegistry) -> None:
|
||||
"""Test that tool metadata has correct structure."""
|
||||
all_tools = registry.get_all_tools()
|
||||
for tool in all_tools:
|
||||
assert isinstance(tool.tool_id, str)
|
||||
assert isinstance(tool.name, str)
|
||||
assert isinstance(tool.category, str)
|
||||
assert tool.category in ["shared", "customer"]
|
||||
assert isinstance(tool.description, str)
|
||||
assert isinstance(tool.tool_instance, BaseTool)
|
||||
assert isinstance(tool.module_path, str)
|
||||
|
||||
def test_list_tool_ids(self, registry: ToolRegistry) -> None:
|
||||
"""Test listing all tool IDs."""
|
||||
tool_ids = registry.list_tool_ids()
|
||||
assert isinstance(tool_ids, list)
|
||||
assert len(tool_ids) > 0
|
||||
assert all(isinstance(tool_id, str) for tool_id in tool_ids)
|
||||
|
||||
# Check that tool IDs follow expected format
|
||||
for tool_id in tool_ids:
|
||||
assert "." in tool_id
|
||||
category, name = tool_id.split(".", 1)
|
||||
assert category in ["shared", "customer"]
|
||||
|
||||
def test_get_specific_tool(self, registry: ToolRegistry) -> None:
|
||||
"""Test retrieving a specific tool by ID."""
|
||||
# Get all tool IDs first
|
||||
tool_ids = registry.list_tool_ids()
|
||||
if tool_ids:
|
||||
# Test with first available tool
|
||||
test_tool_id = tool_ids[0]
|
||||
tool_metadata = registry.get_tool(tool_id=test_tool_id)
|
||||
|
||||
assert tool_metadata is not None
|
||||
assert isinstance(tool_metadata, ToolMetadata)
|
||||
assert tool_metadata.tool_id == test_tool_id
|
||||
|
||||
def test_get_nonexistent_tool(self, registry: ToolRegistry) -> None:
|
||||
"""Test retrieving a tool that doesn't exist."""
|
||||
tool_metadata = registry.get_tool(tool_id="nonexistent.tool")
|
||||
assert tool_metadata is None
|
||||
|
||||
def test_get_tools_by_category_shared(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting all shared tools."""
|
||||
shared_tools = registry.get_tools_by_category(category="shared")
|
||||
assert isinstance(shared_tools, list)
|
||||
assert all(tool.category == "shared" for tool in shared_tools)
|
||||
|
||||
def test_get_tools_by_category_customer(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting all customer tools."""
|
||||
customer_tools = registry.get_tools_by_category(category="customer")
|
||||
assert isinstance(customer_tools, list)
|
||||
assert all(tool.category == "customer" for tool in customer_tools)
|
||||
|
||||
def test_get_tool_instances(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting tool instances by IDs."""
|
||||
tool_ids = registry.list_tool_ids()
|
||||
if len(tool_ids) >= 2:
|
||||
# Test with first two tools
|
||||
test_ids = tool_ids[:2]
|
||||
instances = registry.get_tool_instances(tool_ids=test_ids)
|
||||
|
||||
assert isinstance(instances, list)
|
||||
assert len(instances) == 2
|
||||
assert all(isinstance(inst, BaseTool) for inst in instances)
|
||||
|
||||
def test_get_tool_instances_with_invalid_id(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting tool instances with some invalid IDs."""
|
||||
tool_ids = registry.list_tool_ids()
|
||||
if tool_ids:
|
||||
# Mix valid and invalid IDs
|
||||
test_ids = [tool_ids[0], "invalid.tool"]
|
||||
instances = registry.get_tool_instances(tool_ids=test_ids)
|
||||
|
||||
# Should only return the valid one
|
||||
assert len(instances) == 1
|
||||
assert isinstance(instances[0], BaseTool)
|
||||
|
||||
def test_global_registry_singleton(self) -> None:
|
||||
"""Test that get_registry returns same instance."""
|
||||
registry1 = get_registry()
|
||||
registry2 = get_registry()
|
||||
assert registry1 is registry2
|
||||
|
||||
def test_reinitialize_registry(self) -> None:
|
||||
"""Test that reinitialize creates new instance."""
|
||||
registry1 = get_registry()
|
||||
registry2 = reinitialize_registry()
|
||||
# Should be different instances after reinitialize
|
||||
assert registry1 is not registry2
|
||||
assert registry2.is_initialized
|
||||
|
||||
|
||||
class TestToolDiscovery:
|
||||
"""Test suite for tool discovery functionality."""
|
||||
|
||||
def test_discovers_at_least_one_tool(self) -> None:
|
||||
"""Test that at least one tool is discovered."""
|
||||
registry = get_registry()
|
||||
tool_ids = registry.list_tool_ids()
|
||||
|
||||
# At least one tool should be successfully loaded
|
||||
assert len(tool_ids) >= 1, "Expected at least one tool to be discovered"
|
||||
|
||||
def test_query_althaus_database_if_available(self) -> None:
|
||||
"""Test query_althaus_database tool if it was successfully loaded."""
|
||||
registry = get_registry()
|
||||
tool = registry.get_tool(tool_id="customer.query_althaus_database")
|
||||
|
||||
if tool is not None:
|
||||
assert tool.name == "query_althaus_database"
|
||||
assert tool.category == "customer"
|
||||
assert "database" in tool.description.lower()
|
||||
else:
|
||||
# Tool may not have loaded due to import errors - log warning
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"customer.query_althaus_database tool not found - "
|
||||
"may have failed to import"
|
||||
)
|
||||
|
||||
def test_tavily_search_if_available(self) -> None:
|
||||
"""Test tavily_search tool if it was successfully loaded."""
|
||||
registry = get_registry()
|
||||
tool = registry.get_tool(tool_id="shared.tavily_search")
|
||||
|
||||
if tool is not None:
|
||||
assert tool.name == "tavily_search"
|
||||
assert tool.category == "shared"
|
||||
assert "search" in tool.description.lower()
|
||||
else:
|
||||
# Tool may not have loaded due to import errors - log warning
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"shared.tavily_search tool not found - may have failed to import"
|
||||
)
|
||||
|
||||
def test_tool_ids_have_correct_format(self) -> None:
|
||||
"""Test that all discovered tool IDs follow the expected format."""
|
||||
registry = get_registry()
|
||||
tool_ids = registry.list_tool_ids()
|
||||
|
||||
for tool_id in tool_ids:
|
||||
# All tool IDs should have format: category.toolname
|
||||
assert "." in tool_id, f"Tool ID {tool_id} missing category separator"
|
||||
category, name = tool_id.split(".", 1)
|
||||
assert category in [
|
||||
"shared",
|
||||
"customer",
|
||||
], f"Tool {tool_id} has invalid category: {category}"
|
||||
assert len(name) > 0, f"Tool {tool_id} has empty name"
|
||||
Loading…
Reference in a new issue