Merge branch 'int' of https://github.com/valueonag/gateway into int
This commit is contained in:
commit
82b2fd36dc
26 changed files with 5672 additions and 838 deletions
236
app.py
236
app.py
|
|
@ -1,8 +1,20 @@
|
|||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
from urllib.parse import quote_plus
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||
from modules.features.chatBot.database import init_models as init_chatbot_models
|
||||
|
||||
os.environ["NUMEXPR_MAX_THREADS"] = "12"
|
||||
|
||||
# Fix for Windows asyncio compatibility with psycopg
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends, Body, status, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
|
||||
from fastapi.security import HTTPBearer
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
|
|
@ -21,7 +33,9 @@ class DailyRotatingFileHandler(RotatingFileHandler):
|
|||
The log file name includes the current date and switches at midnight.
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir, filename_prefix, max_bytes=10485760, backup_count=5, **kwargs):
|
||||
def __init__(
|
||||
self, log_dir, filename_prefix, max_bytes=10485760, backup_count=5, **kwargs
|
||||
):
|
||||
self.log_dir = log_dir
|
||||
self.filename_prefix = filename_prefix
|
||||
self.current_date = None
|
||||
|
|
@ -31,7 +45,9 @@ class DailyRotatingFileHandler(RotatingFileHandler):
|
|||
self._update_file_if_needed()
|
||||
|
||||
# Call parent constructor with current file
|
||||
super().__init__(self.current_file, maxBytes=max_bytes, backupCount=backup_count, **kwargs)
|
||||
super().__init__(
|
||||
self.current_file, maxBytes=max_bytes, backupCount=backup_count, **kwargs
|
||||
)
|
||||
|
||||
def _update_file_if_needed(self):
|
||||
"""Update the log file if the date has changed"""
|
||||
|
|
@ -64,6 +80,7 @@ class DailyRotatingFileHandler(RotatingFileHandler):
|
|||
# Call parent emit method
|
||||
super().emit(record)
|
||||
|
||||
|
||||
def initLogging():
|
||||
"""Initialize logging with configuration from APP_CONFIG"""
|
||||
# Get log level from config (default to INFO if not found)
|
||||
|
|
@ -83,26 +100,32 @@ def initLogging():
|
|||
# Create formatters - using single line format
|
||||
consoleFormatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
# File formatter with more detailed error information but still single line
|
||||
fileFormatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s - %(pathname)s:%(lineno)d - %(funcName)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
# Add filter to exclude Chrome DevTools requests
|
||||
class ChromeDevToolsFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
return not (isinstance(record.msg, str) and
|
||||
('.well-known/appspecific/com.chrome.devtools.json' in record.msg or
|
||||
'Request: /index.html' in record.msg))
|
||||
return not (
|
||||
isinstance(record.msg, str)
|
||||
and (
|
||||
".well-known/appspecific/com.chrome.devtools.json" in record.msg
|
||||
or "Request: /index.html" in record.msg
|
||||
)
|
||||
)
|
||||
|
||||
# Add filter to exclude all httpcore loggers (including sub-loggers)
|
||||
class HttpcoreStarFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
return not (record.name == 'httpcore' or record.name.startswith('httpcore.'))
|
||||
return not (
|
||||
record.name == "httpcore" or record.name.startswith("httpcore.")
|
||||
)
|
||||
|
||||
# Add filter to exclude HTTP debug messages
|
||||
class HTTPDebugFilter(logging.Filter):
|
||||
|
|
@ -110,14 +133,14 @@ def initLogging():
|
|||
if isinstance(record.msg, str):
|
||||
# Filter out HTTP debug messages
|
||||
http_debug_patterns = [
|
||||
'receive_response_body.started',
|
||||
'receive_response_body.complete',
|
||||
'response_closed.started',
|
||||
'_send_single_request',
|
||||
'httpcore.http11',
|
||||
'httpx._client',
|
||||
'HTTP Request',
|
||||
'multipart.multipart'
|
||||
"receive_response_body.started",
|
||||
"receive_response_body.complete",
|
||||
"response_closed.started",
|
||||
"_send_single_request",
|
||||
"httpcore.http11",
|
||||
"httpx._client",
|
||||
"HTTP Request",
|
||||
"multipart.multipart",
|
||||
]
|
||||
return not any(pattern in record.msg for pattern in http_debug_patterns)
|
||||
return True
|
||||
|
|
@ -129,13 +152,21 @@ def initLogging():
|
|||
# Remove only emojis, preserve other Unicode characters like quotes
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
# Remove emoji characters specifically
|
||||
record.msg = ''.join(char for char in record.msg if unicodedata.category(char) != 'So' or not (0x1F600 <= ord(char) <= 0x1F64F or 0x1F300 <= ord(char) <= 0x1F5FF or 0x1F680 <= ord(char) <= 0x1F6FF or 0x1F1E0 <= ord(char) <= 0x1F1FF or 0x2600 <= ord(char) <= 0x26FF or 0x2700 <= ord(char) <= 0x27BF))
|
||||
# Additionally strip characters not representable in Windows cp1252 (e.g., arrows)
|
||||
try:
|
||||
record.msg.encode('cp1252', errors='strict')
|
||||
except UnicodeEncodeError:
|
||||
record.msg = record.msg.encode('cp1252', errors='ignore').decode('cp1252', errors='ignore')
|
||||
record.msg = "".join(
|
||||
char
|
||||
for char in record.msg
|
||||
if unicodedata.category(char) != "So"
|
||||
or not (
|
||||
0x1F600 <= ord(char) <= 0x1F64F
|
||||
or 0x1F300 <= ord(char) <= 0x1F5FF
|
||||
or 0x1F680 <= ord(char) <= 0x1F6FF
|
||||
or 0x1F1E0 <= ord(char) <= 0x1F1FF
|
||||
or 0x2600 <= ord(char) <= 0x26FF
|
||||
or 0x2700 <= ord(char) <= 0x27BF
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
# Configure handlers based on config
|
||||
|
|
@ -154,14 +185,16 @@ def initLogging():
|
|||
# Add file handler if enabled
|
||||
if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True):
|
||||
# Create daily application log file with automatic date switching
|
||||
rotationSize = int(APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760)) # Default: 10MB
|
||||
rotationSize = int(
|
||||
APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760)
|
||||
) # Default: 10MB
|
||||
backupCount = int(APP_CONFIG.get("APP_LOGGING_BACKUP_COUNT", 5))
|
||||
|
||||
fileHandler = DailyRotatingFileHandler(
|
||||
log_dir=logDir,
|
||||
filename_prefix="log_app",
|
||||
max_bytes=rotationSize,
|
||||
backup_count=backupCount
|
||||
backup_count=backupCount,
|
||||
)
|
||||
fileHandler.setFormatter(fileFormatter)
|
||||
fileHandler.addFilter(ChromeDevToolsFilter())
|
||||
|
|
@ -176,11 +209,18 @@ def initLogging():
|
|||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
handlers=handlers,
|
||||
force=True # Force reconfiguration of the root logger
|
||||
force=True, # Force reconfiguration of the root logger
|
||||
)
|
||||
|
||||
# Silence noisy third-party libraries - use the same level as the root logger
|
||||
noisyLoggers = ["httpx", "httpcore", "urllib3", "asyncio", "fastapi.security.oauth2", "msal"]
|
||||
noisyLoggers = [
|
||||
"httpx",
|
||||
"httpcore",
|
||||
"urllib3",
|
||||
"asyncio",
|
||||
"fastapi.security.oauth2",
|
||||
"msal",
|
||||
]
|
||||
for loggerName in noisyLoggers:
|
||||
logging.getLogger(loggerName).setLevel(logging.WARNING)
|
||||
|
||||
|
|
@ -189,27 +229,88 @@ def initLogging():
|
|||
logger.info(f"Logging initialized with level {logLevelName}")
|
||||
logger.info(f"Log directory: {logDir}")
|
||||
|
||||
if APP_CONFIG.get('APP_LOGGING_FILE_ENABLED', True):
|
||||
if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True):
|
||||
today = datetime.now().strftime("%Y%m%d")
|
||||
appLogFile = os.path.join(logDir, f"log_app_{today}.log")
|
||||
logger.info(f"Application log file: {appLogFile} (auto-switches daily)")
|
||||
else:
|
||||
logger.info("Application log file: disabled")
|
||||
|
||||
logger.info(f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}")
|
||||
logger.info(
|
||||
f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}"
|
||||
)
|
||||
|
||||
|
||||
def make_sqlalchemy_db_url() -> str:
|
||||
host = APP_CONFIG.get("SQLALCHEMY_DB_HOST", "localhost")
|
||||
port = APP_CONFIG.get("SQLALCHEMY_DB_PORT", "5432")
|
||||
db = APP_CONFIG.get("SQLALCHEMY_DB_DATABASE", "project_gateway")
|
||||
user = APP_CONFIG.get("SQLALCHEMY_DB_USER", "postgres")
|
||||
pwd = quote_plus(APP_CONFIG.get("SQLALCHEMY_DB_PASSWORD_SECRET", ""))
|
||||
return f"postgresql+psycopg://{user}:{pwd}@{host}:{port}/{db}"
|
||||
|
||||
|
||||
# Initialize logging
|
||||
initLogging()
|
||||
logger = logging.getLogger(__name__)
|
||||
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
|
||||
|
||||
|
||||
# Define lifespan context manager for application startup/shutdown events
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Application is starting up")
|
||||
|
||||
# --- Init SQLAlchemy ---
|
||||
|
||||
engine = create_async_engine(
|
||||
make_sqlalchemy_db_url(), pool_pre_ping=True, echo=False
|
||||
)
|
||||
SessionLocal = async_sessionmaker(
|
||||
engine, expire_on_commit=False, class_=AsyncSession
|
||||
)
|
||||
app.state.checkpoint_engine = engine
|
||||
app.state.checkpoint_sessionmaker = SessionLocal
|
||||
|
||||
# NOTE: Might need Alembic migrations in the future
|
||||
await init_chatbot_models(engine)
|
||||
|
||||
# --- Sync tools from registry to database ---
|
||||
from modules.features.chatBot.database import sync_tools_from_registry
|
||||
|
||||
async with SessionLocal() as session:
|
||||
await sync_tools_from_registry(session)
|
||||
await session.commit()
|
||||
logger.info("Tools synced from registry to database")
|
||||
|
||||
# --- Initialize LangGraph checkpointer ---
|
||||
|
||||
from modules.features.chatBot.utils.checkpointer import (
|
||||
initialize_checkpointer,
|
||||
close_checkpointer,
|
||||
)
|
||||
|
||||
try:
|
||||
await initialize_checkpointer()
|
||||
logger.info("LangGraph checkpointer initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LangGraph checkpointer: {str(e)}")
|
||||
# Continue startup even if checkpointer fails to initialize
|
||||
|
||||
# --- Init Event Manager ---
|
||||
eventManager.start()
|
||||
|
||||
yield
|
||||
|
||||
# --- Cleanup Event Manager ---
|
||||
eventManager.stop()
|
||||
|
||||
# --- Cleanup LangGraph checkpointer ---
|
||||
await close_checkpointer()
|
||||
|
||||
# --- Cleanup SQLAlchemy ---
|
||||
await engine.dispose()
|
||||
|
||||
logger.info("Application has been shut down")
|
||||
|
||||
|
||||
|
|
@ -217,9 +318,52 @@ async def lifespan(app: FastAPI):
|
|||
app = FastAPI(
|
||||
title="PowerOn | Data Platform API",
|
||||
description=f"Backend API for the Multi-Agent Platform by ValueOn AG ({instanceLabel})",
|
||||
lifespan=lifespan
|
||||
lifespan=lifespan,
|
||||
swagger_ui_init_oauth={
|
||||
"usePkceWithAuthorizationCodeGrant": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Configure OpenAPI security scheme for Swagger UI
|
||||
# This adds the "Authorize" button to the /docs page
|
||||
security_scheme = HTTPBearer()
|
||||
app.openapi_schema = None # Reset schema to regenerate with security
|
||||
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version="1.0.0",
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add security scheme definition
|
||||
openapi_schema["components"]["securitySchemes"] = {
|
||||
"BearerAuth": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT",
|
||||
"description": "Enter your JWT token (obtained from login endpoint or browser cookies)",
|
||||
}
|
||||
}
|
||||
|
||||
# Apply security globally to all endpoints
|
||||
# Individual endpoints can override this if needed
|
||||
openapi_schema["security"] = [{"BearerAuth": []}]
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
|
||||
# Parse CORS origins from environment variable
|
||||
def get_allowed_origins():
|
||||
origins_str = APP_CONFIG.get("APP_ALLOWED_ORIGINS", "http://localhost:8080")
|
||||
|
|
@ -228,73 +372,99 @@ def get_allowed_origins():
|
|||
logger.info(f"CORS allowed origins: {origins}")
|
||||
return origins
|
||||
|
||||
|
||||
# CORS configuration using environment variables
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins= get_allowed_origins(),
|
||||
allow_origins=get_allowed_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"],
|
||||
max_age=86400 # Increased caching for preflight requests
|
||||
max_age=86400, # Increased caching for preflight requests
|
||||
)
|
||||
|
||||
# CSRF protection middleware
|
||||
from modules.security.csrf import CSRFMiddleware
|
||||
from modules.security.tokenRefreshMiddleware import TokenRefreshMiddleware, ProactiveTokenRefreshMiddleware
|
||||
from modules.security.tokenRefreshMiddleware import (
|
||||
TokenRefreshMiddleware,
|
||||
ProactiveTokenRefreshMiddleware,
|
||||
)
|
||||
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# Token refresh middleware (silent refresh for expired OAuth tokens)
|
||||
app.add_middleware(TokenRefreshMiddleware, enabled=True)
|
||||
|
||||
# Proactive token refresh middleware (refresh tokens before they expire)
|
||||
app.add_middleware(ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5)
|
||||
app.add_middleware(
|
||||
ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5
|
||||
)
|
||||
|
||||
# Run triggered features
|
||||
import modules.features.init
|
||||
|
||||
# Include all routers
|
||||
from modules.routes.routeAdmin import router as generalRouter
|
||||
|
||||
app.include_router(generalRouter)
|
||||
|
||||
from modules.routes.routeAttributes import router as attributesRouter
|
||||
|
||||
app.include_router(attributesRouter)
|
||||
|
||||
from modules.routes.routeDataMandates import router as mandateRouter
|
||||
|
||||
app.include_router(mandateRouter)
|
||||
|
||||
from modules.routes.routeDataUsers import router as userRouter
|
||||
|
||||
app.include_router(userRouter)
|
||||
|
||||
from modules.routes.routeDataFiles import router as fileRouter
|
||||
|
||||
app.include_router(fileRouter)
|
||||
|
||||
from modules.routes.routeDataNeutralization import router as neutralizationRouter
|
||||
|
||||
app.include_router(neutralizationRouter)
|
||||
|
||||
from modules.routes.routeDataPrompts import router as promptRouter
|
||||
|
||||
app.include_router(promptRouter)
|
||||
|
||||
from modules.routes.routeDataConnections import router as connectionsRouter
|
||||
|
||||
app.include_router(connectionsRouter)
|
||||
|
||||
from modules.routes.routeWorkflows import router as workflowRouter
|
||||
|
||||
app.include_router(workflowRouter)
|
||||
|
||||
from modules.routes.routeChatPlayground import router as chatPlaygroundRouter
|
||||
|
||||
app.include_router(chatPlaygroundRouter)
|
||||
|
||||
from modules.routes.routeSecurityLocal import router as localRouter
|
||||
|
||||
app.include_router(localRouter)
|
||||
|
||||
from modules.routes.routeSecurityMsft import router as msftRouter
|
||||
|
||||
app.include_router(msftRouter)
|
||||
|
||||
from modules.routes.routeSecurityGoogle import router as googleRouter
|
||||
|
||||
app.include_router(googleRouter)
|
||||
|
||||
from modules.routes.routeVoiceGoogle import router as voiceGoogleRouter
|
||||
|
||||
app.include_router(voiceGoogleRouter)
|
||||
|
||||
from modules.routes.routeSecurityAdmin import router as adminSecurityRouter
|
||||
|
||||
app.include_router(adminSecurityRouter)
|
||||
|
||||
from modules.routes.routeChatbot import router as chatbotRouter
|
||||
|
||||
app.include_router(chatbotRouter)
|
||||
|
|
|
|||
|
|
@ -18,60 +18,100 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# No mapping needed - table name = Pydantic model name exactly
|
||||
|
||||
|
||||
class SystemTable(BaseModel, ModelMixin):
|
||||
"""Data model for system table entries"""
|
||||
|
||||
table_name: str = Field(
|
||||
description="Name of the table",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True
|
||||
frontend_required=True,
|
||||
)
|
||||
initial_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Initial ID for the table",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
frontend_required=False,
|
||||
)
|
||||
|
||||
|
||||
def _get_model_fields(model_class) -> Dict[str, str]:
|
||||
"""Get all fields from Pydantic model and map to SQL types."""
|
||||
if not hasattr(model_class, '__fields__'):
|
||||
# Pydantic v2 uses model_fields instead of __fields__
|
||||
if hasattr(model_class, "model_fields"):
|
||||
model_fields = model_class.model_fields
|
||||
elif hasattr(model_class, "__fields__"):
|
||||
model_fields = model_class.__fields__
|
||||
else:
|
||||
return {}
|
||||
|
||||
fields = {}
|
||||
for field_name, field_info in model_class.__fields__.items():
|
||||
field_type = field_info.type_
|
||||
for field_name, field_info in model_fields.items():
|
||||
# Pydantic v2 uses annotation instead of type_
|
||||
field_type = (
|
||||
field_info.annotation
|
||||
if hasattr(field_info, "annotation")
|
||||
else field_info.type_
|
||||
)
|
||||
|
||||
# Check for JSONB fields (Dict, List, or complex types)
|
||||
if (field_type == dict or
|
||||
field_type == list or
|
||||
(hasattr(field_type, '__origin__') and field_type.__origin__ in (dict, list)) or
|
||||
field_name in ['execParameters', 'expectedDocumentFormats', 'resultDocuments', 'logs', 'messages', 'stats', 'tasks']):
|
||||
fields[field_name] = 'JSONB'
|
||||
if (
|
||||
field_type == dict
|
||||
or field_type == list
|
||||
or (
|
||||
hasattr(field_type, "__origin__")
|
||||
and field_type.__origin__ in (dict, list)
|
||||
)
|
||||
or field_name
|
||||
in [
|
||||
"execParameters",
|
||||
"expectedDocumentFormats",
|
||||
"resultDocuments",
|
||||
"logs",
|
||||
"messages",
|
||||
"stats",
|
||||
"tasks",
|
||||
]
|
||||
):
|
||||
fields[field_name] = "JSONB"
|
||||
# Simple type mapping
|
||||
elif field_type in (str, type(None)) or (get_origin(field_type) is Union and type(None) in get_args(field_type)):
|
||||
fields[field_name] = 'TEXT'
|
||||
elif field_type in (str, type(None)) or (
|
||||
get_origin(field_type) is Union and type(None) in get_args(field_type)
|
||||
):
|
||||
fields[field_name] = "TEXT"
|
||||
elif field_type == int:
|
||||
fields[field_name] = 'INTEGER'
|
||||
fields[field_name] = "INTEGER"
|
||||
elif field_type == float:
|
||||
fields[field_name] = 'DOUBLE PRECISION'
|
||||
fields[field_name] = "DOUBLE PRECISION"
|
||||
elif field_type == bool:
|
||||
fields[field_name] = 'BOOLEAN'
|
||||
fields[field_name] = "BOOLEAN"
|
||||
else:
|
||||
fields[field_name] = 'TEXT' # Default to TEXT
|
||||
fields[field_name] = "TEXT" # Default to TEXT
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
# No caching needed with proper database
|
||||
|
||||
|
||||
class DatabaseConnector:
|
||||
"""
|
||||
A connector for PostgreSQL-based data storage.
|
||||
Provides generic database operations without user/mandate filtering.
|
||||
Uses PostgreSQL with JSONB columns for flexible data storage.
|
||||
"""
|
||||
def __init__(self, dbHost: str, dbDatabase: str, dbUser: str = None, dbPassword: str = None, dbPort: int = None, userId: str = None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dbHost: str,
|
||||
dbDatabase: str,
|
||||
dbUser: str = None,
|
||||
dbPassword: str = None,
|
||||
dbPort: int = None,
|
||||
userId: str = None,
|
||||
):
|
||||
# Store the input parameters
|
||||
self.dbHost = dbHost
|
||||
self.dbDatabase = dbDatabase
|
||||
|
|
@ -95,7 +135,6 @@ class DatabaseConnector:
|
|||
self._systemTableName = "_system"
|
||||
self._initializeSystemTable()
|
||||
|
||||
|
||||
def initDbSystem(self):
|
||||
"""Initialize the database system - creates database and tables."""
|
||||
try:
|
||||
|
|
@ -123,13 +162,15 @@ class DatabaseConnector:
|
|||
database="postgres",
|
||||
user=self.dbUser,
|
||||
password=self.dbPassword,
|
||||
client_encoding='utf8'
|
||||
client_encoding="utf8",
|
||||
)
|
||||
conn.autocommit = True
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Check if database exists
|
||||
cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,))
|
||||
cursor.execute(
|
||||
"SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
|
||||
)
|
||||
exists = cursor.fetchone()
|
||||
|
||||
if not exists:
|
||||
|
|
@ -143,8 +184,9 @@ class DatabaseConnector:
|
|||
except Exception as e:
|
||||
logger.error(f"FATAL ERROR: Cannot create database: {e}")
|
||||
logger.error("Database connection failed - application cannot start")
|
||||
raise RuntimeError(f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}")
|
||||
|
||||
raise RuntimeError(
|
||||
f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}"
|
||||
)
|
||||
|
||||
def _create_tables(self):
|
||||
"""Create only the system table - application tables are created by interfaces."""
|
||||
|
|
@ -156,7 +198,7 @@ class DatabaseConnector:
|
|||
database=self.dbDatabase,
|
||||
user=self.dbUser,
|
||||
password=self.dbPassword,
|
||||
client_encoding='utf8'
|
||||
client_encoding="utf8",
|
||||
)
|
||||
conn.autocommit = True
|
||||
|
||||
|
|
@ -175,7 +217,9 @@ class DatabaseConnector:
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"FATAL ERROR: Cannot create system table: {e}")
|
||||
logger.error("Database system table creation failed - application cannot start")
|
||||
logger.error(
|
||||
"Database system table creation failed - application cannot start"
|
||||
)
|
||||
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
|
||||
|
||||
def _connect(self):
|
||||
|
|
@ -188,8 +232,8 @@ class DatabaseConnector:
|
|||
database=self.dbDatabase,
|
||||
user=self.dbUser,
|
||||
password=self.dbPassword,
|
||||
client_encoding='utf8',
|
||||
cursor_factory=psycopg2.extras.RealDictCursor
|
||||
client_encoding="utf8",
|
||||
cursor_factory=psycopg2.extras.RealDictCursor,
|
||||
)
|
||||
self.connection.autocommit = False # Use transactions
|
||||
except Exception as e:
|
||||
|
|
@ -219,7 +263,7 @@ class DatabaseConnector:
|
|||
# Check if system table has any data
|
||||
cursor.execute('SELECT COUNT(*) FROM "_system"')
|
||||
row = cursor.fetchone()
|
||||
count = row['count'] if row else 0
|
||||
count = row["count"] if row else 0
|
||||
|
||||
self.connection.commit()
|
||||
except Exception as e:
|
||||
|
|
@ -236,7 +280,7 @@ class DatabaseConnector:
|
|||
|
||||
system_data = {}
|
||||
for row in rows:
|
||||
system_data[row['table_name']] = row['initial_id']
|
||||
system_data[row["table_name"]] = row["initial_id"]
|
||||
|
||||
return system_data
|
||||
except Exception as e:
|
||||
|
|
@ -252,10 +296,13 @@ class DatabaseConnector:
|
|||
|
||||
# Insert new data
|
||||
for table_name, initial_id in data.items():
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO "_system" ("table_name", "initial_id", "_modifiedAt")
|
||||
VALUES (%s, %s, %s)
|
||||
""", (table_name, initial_id, get_utc_timestamp()))
|
||||
""",
|
||||
(table_name, initial_id, get_utc_timestamp()),
|
||||
)
|
||||
|
||||
self.connection.commit()
|
||||
return True
|
||||
|
|
@ -271,8 +318,11 @@ class DatabaseConnector:
|
|||
|
||||
with self.connection.cursor() as cursor:
|
||||
# Check if system table exists
|
||||
cursor.execute("SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", (self._systemTableName,))
|
||||
exists = cursor.fetchone()['count'] > 0
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
|
||||
(self._systemTableName,),
|
||||
)
|
||||
exists = cursor.fetchone()["count"] > 0
|
||||
|
||||
if not exists:
|
||||
# Create system table
|
||||
|
|
@ -287,14 +337,19 @@ class DatabaseConnector:
|
|||
logger.info("System table created successfully")
|
||||
else:
|
||||
# Check if we need to add missing columns to existing table
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT column_name FROM information_schema.columns
|
||||
WHERE table_name = %s AND table_schema = 'public'
|
||||
""", (self._systemTableName,))
|
||||
existing_columns = [row['column_name'] for row in cursor.fetchall()]
|
||||
""",
|
||||
(self._systemTableName,),
|
||||
)
|
||||
existing_columns = [row["column_name"] for row in cursor.fetchall()]
|
||||
|
||||
if '_modifiedAt' not in existing_columns:
|
||||
cursor.execute(f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "_modifiedAt" DOUBLE PRECISION')
|
||||
if "_modifiedAt" not in existing_columns:
|
||||
cursor.execute(
|
||||
f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "_modifiedAt" DOUBLE PRECISION'
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
|
@ -314,60 +369,81 @@ class DatabaseConnector:
|
|||
|
||||
with self.connection.cursor() as cursor:
|
||||
# Check if table exists by querying information_schema with case-insensitive search
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM information_schema.tables
|
||||
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
|
||||
''', (table,))
|
||||
exists = cursor.fetchone()['count'] > 0
|
||||
""",
|
||||
(table,),
|
||||
)
|
||||
exists = cursor.fetchone()["count"] > 0
|
||||
|
||||
if not exists:
|
||||
# Create table from Pydantic model
|
||||
self._create_table_from_model(cursor, table, model_class)
|
||||
logger.info(f"Created table '{table}' with columns from Pydantic model")
|
||||
logger.info(
|
||||
f"Created table '{table}' with columns from Pydantic model"
|
||||
)
|
||||
else:
|
||||
# Table exists: ensure all columns from model are present (simple additive migration)
|
||||
try:
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT column_name FROM information_schema.columns
|
||||
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
|
||||
""", (table,))
|
||||
existing_columns = {row['column_name'] for row in cursor.fetchall()}
|
||||
""",
|
||||
(table,),
|
||||
)
|
||||
existing_columns = {
|
||||
row["column_name"] for row in cursor.fetchall()
|
||||
}
|
||||
|
||||
# Desired columns based on model
|
||||
model_fields = _get_model_fields(model_class)
|
||||
desired_columns = set(['id']) | set(model_fields.keys()) | {'_createdAt', '_modifiedAt', '_createdBy', '_modifiedBy'}
|
||||
desired_columns = (
|
||||
set(["id"])
|
||||
| set(model_fields.keys())
|
||||
| {"_createdAt", "_modifiedAt", "_createdBy", "_modifiedBy"}
|
||||
)
|
||||
|
||||
# Add missing columns
|
||||
for col in sorted(desired_columns - existing_columns):
|
||||
# Determine SQL type
|
||||
if col in ['id']:
|
||||
if col in ["id"]:
|
||||
continue # primary key exists already
|
||||
sql_type = model_fields.get(col)
|
||||
if col in ['_createdAt']:
|
||||
sql_type = 'DOUBLE PRECISION'
|
||||
elif col in ['_modifiedAt']:
|
||||
sql_type = 'DOUBLE PRECISION'
|
||||
elif col in ['_createdBy', '_modifiedBy']:
|
||||
sql_type = 'VARCHAR(255)'
|
||||
if col in ["_createdAt"]:
|
||||
sql_type = "DOUBLE PRECISION"
|
||||
elif col in ["_modifiedAt"]:
|
||||
sql_type = "DOUBLE PRECISION"
|
||||
elif col in ["_createdBy", "_modifiedBy"]:
|
||||
sql_type = "VARCHAR(255)"
|
||||
if not sql_type:
|
||||
sql_type = 'TEXT'
|
||||
sql_type = "TEXT"
|
||||
try:
|
||||
cursor.execute(f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}')
|
||||
logger.info(f"Added missing column '{col}' ({sql_type}) to '{table}'")
|
||||
cursor.execute(
|
||||
f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}'
|
||||
)
|
||||
logger.info(
|
||||
f"Added missing column '{col}' ({sql_type}) to '{table}'"
|
||||
)
|
||||
except Exception as add_err:
|
||||
logger.warning(f"Could not add column '{col}' to '{table}': {add_err}")
|
||||
logger.warning(
|
||||
f"Could not add column '{col}' to '{table}': {add_err}"
|
||||
)
|
||||
except Exception as ensure_err:
|
||||
logger.warning(f"Could not ensure columns for existing table '{table}': {ensure_err}")
|
||||
logger.warning(
|
||||
f"Could not ensure columns for existing table '{table}': {ensure_err}"
|
||||
)
|
||||
|
||||
self.connection.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring table {table} exists: {e}")
|
||||
if hasattr(self, 'connection') and self.connection:
|
||||
if hasattr(self, "connection") and self.connection:
|
||||
self.connection.rollback()
|
||||
return False
|
||||
|
||||
|
||||
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
|
||||
"""Create table with columns matching Pydantic model fields."""
|
||||
fields = _get_model_fields(model_class)
|
||||
|
|
@ -375,16 +451,18 @@ class DatabaseConnector:
|
|||
# Build column definitions with quoted identifiers to preserve exact case
|
||||
columns = ['"id" VARCHAR(255) PRIMARY KEY']
|
||||
for field_name, sql_type in fields.items():
|
||||
if field_name != 'id': # Skip id, already defined
|
||||
if field_name != "id": # Skip id, already defined
|
||||
columns.append(f'"{field_name}" {sql_type}')
|
||||
|
||||
# Add metadata columns
|
||||
columns.extend([
|
||||
columns.extend(
|
||||
[
|
||||
'"_createdAt" DOUBLE PRECISION',
|
||||
'"_modifiedAt" DOUBLE PRECISION',
|
||||
'"_createdBy" VARCHAR(255)',
|
||||
'"_modifiedBy" VARCHAR(255)'
|
||||
])
|
||||
'"_modifiedBy" VARCHAR(255)',
|
||||
]
|
||||
)
|
||||
|
||||
# Create table
|
||||
sql = f'CREATE TABLE IF NOT EXISTS "{table}" ({", ".join(columns)})'
|
||||
|
|
@ -392,16 +470,27 @@ class DatabaseConnector:
|
|||
|
||||
# Create indexes for foreign keys
|
||||
for field_name in fields:
|
||||
if field_name.endswith('Id') and field_name != 'id':
|
||||
cursor.execute(f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")')
|
||||
if field_name.endswith("Id") and field_name != "id":
|
||||
cursor.execute(
|
||||
f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")'
|
||||
)
|
||||
|
||||
|
||||
def _save_record(self, cursor, table: str, recordId: str, record: Dict[str, Any], model_class: type) -> None:
|
||||
def _save_record(
|
||||
self,
|
||||
cursor,
|
||||
table: str,
|
||||
recordId: str,
|
||||
record: Dict[str, Any],
|
||||
model_class: type,
|
||||
) -> None:
|
||||
"""Save record to normalized table with explicit columns."""
|
||||
# Get columns from Pydantic model instead of database schema
|
||||
fields = _get_model_fields(model_class)
|
||||
columns = ['id'] + [field for field in fields.keys() if field != 'id'] + ['_createdAt', '_createdBy', '_modifiedAt', '_modifiedBy']
|
||||
|
||||
columns = (
|
||||
["id"]
|
||||
+ [field for field in fields.keys() if field != "id"]
|
||||
+ ["_createdAt", "_createdBy", "_modifiedAt", "_modifiedBy"]
|
||||
)
|
||||
|
||||
if not columns:
|
||||
logger.error(f"No columns found for table {table}")
|
||||
|
|
@ -411,7 +500,7 @@ class DatabaseConnector:
|
|||
filtered_record = {k: v for k, v in record.items() if k in columns}
|
||||
|
||||
# Ensure id is set
|
||||
filtered_record['id'] = recordId
|
||||
filtered_record["id"] = recordId
|
||||
|
||||
# Prepare values in the correct order
|
||||
values = []
|
||||
|
|
@ -419,7 +508,7 @@ class DatabaseConnector:
|
|||
value = filtered_record.get(col)
|
||||
|
||||
# Handle timestamp fields - store as Unix timestamps (floats) for consistency
|
||||
if col in ['_createdAt', '_modifiedAt'] and value is not None:
|
||||
if col in ["_createdAt", "_modifiedAt"] and value is not None:
|
||||
if isinstance(value, str):
|
||||
# Try to parse string as timestamp
|
||||
try:
|
||||
|
|
@ -428,12 +517,13 @@ class DatabaseConnector:
|
|||
pass # Keep as string if parsing fails
|
||||
|
||||
# Convert enum values to their string representation
|
||||
elif hasattr(value, 'value'):
|
||||
elif hasattr(value, "value"):
|
||||
value = value.value
|
||||
|
||||
# Handle JSONB fields - ensure proper JSON format for PostgreSQL
|
||||
elif col in fields and fields[col] == 'JSONB' and value is not None:
|
||||
elif col in fields and fields[col] == "JSONB" and value is not None:
|
||||
import json
|
||||
|
||||
if isinstance(value, (dict, list)):
|
||||
# Convert Python objects to JSON string for PostgreSQL JSONB
|
||||
value = json.dumps(value)
|
||||
|
|
@ -453,11 +543,16 @@ class DatabaseConnector:
|
|||
|
||||
values.append(value)
|
||||
|
||||
|
||||
# Build INSERT/UPDATE with quoted identifiers
|
||||
col_names = ', '.join([f'"{col}"' for col in columns])
|
||||
placeholders = ', '.join(['%s'] * len(columns))
|
||||
updates = ', '.join([f'"{col}" = EXCLUDED."{col}"' for col in columns[1:] if col not in ['_createdAt', '_createdBy']])
|
||||
col_names = ", ".join([f'"{col}"' for col in columns])
|
||||
placeholders = ", ".join(["%s"] * len(columns))
|
||||
updates = ", ".join(
|
||||
[
|
||||
f'"{col}" = EXCLUDED."{col}"'
|
||||
for col in columns[1:]
|
||||
if col not in ["_createdAt", "_createdBy"]
|
||||
]
|
||||
)
|
||||
|
||||
sql = f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders}) ON CONFLICT ("id") DO UPDATE SET {updates}'
|
||||
|
||||
|
|
@ -481,11 +576,15 @@ class DatabaseConnector:
|
|||
record = dict(row)
|
||||
fields = _get_model_fields(model_class)
|
||||
|
||||
|
||||
# Parse JSONB fields back to Python objects
|
||||
for field_name, field_type in fields.items():
|
||||
if field_type == 'JSONB' and field_name in record and record[field_name] is not None:
|
||||
if (
|
||||
field_type == "JSONB"
|
||||
and field_name in record
|
||||
and record[field_name] is not None
|
||||
):
|
||||
import json
|
||||
|
||||
try:
|
||||
if isinstance(record[field_name], str):
|
||||
# Parse JSON string back to Python object
|
||||
|
|
@ -498,7 +597,9 @@ class DatabaseConnector:
|
|||
record[field_name] = json.loads(str(record[field_name]))
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
# If parsing fails, keep as string
|
||||
logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}")
|
||||
logger.warning(
|
||||
f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
|
||||
)
|
||||
pass
|
||||
|
||||
return record
|
||||
|
|
@ -506,7 +607,9 @@ class DatabaseConnector:
|
|||
logger.error(f"Error loading record {recordId} from table {table}: {e}")
|
||||
return None
|
||||
|
||||
def _saveRecord(self, model_class: type, recordId: str, record: Dict[str, Any]) -> bool:
|
||||
def _saveRecord(
|
||||
self, model_class: type, recordId: str, record: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Saves a single record to the table."""
|
||||
table = model_class.__name__
|
||||
|
||||
|
|
@ -555,30 +658,43 @@ class DatabaseConnector:
|
|||
fields = _get_model_fields(model_class)
|
||||
for record in records:
|
||||
for field_name, field_type in fields.items():
|
||||
if field_type == 'JSONB' and field_name in record:
|
||||
if field_type == "JSONB" and field_name in record:
|
||||
if record[field_name] is None:
|
||||
# Convert None to appropriate default based on field name
|
||||
if field_name in ['logs', 'messages', 'tasks', 'expectedDocumentFormats', 'resultDocuments']:
|
||||
if field_name in [
|
||||
"logs",
|
||||
"messages",
|
||||
"tasks",
|
||||
"expectedDocumentFormats",
|
||||
"resultDocuments",
|
||||
]:
|
||||
record[field_name] = []
|
||||
elif field_name in ['execParameters', 'stats']:
|
||||
elif field_name in ["execParameters", "stats"]:
|
||||
record[field_name] = {}
|
||||
else:
|
||||
record[field_name] = None
|
||||
else:
|
||||
import json
|
||||
|
||||
try:
|
||||
if isinstance(record[field_name], str):
|
||||
# Parse JSON string back to Python object
|
||||
record[field_name] = json.loads(record[field_name])
|
||||
record[field_name] = json.loads(
|
||||
record[field_name]
|
||||
)
|
||||
elif isinstance(record[field_name], (dict, list)):
|
||||
# Already a Python object, keep as is
|
||||
pass
|
||||
else:
|
||||
# Try to parse as JSON
|
||||
record[field_name] = json.loads(str(record[field_name]))
|
||||
record[field_name] = json.loads(
|
||||
str(record[field_name])
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
# If parsing fails, keep as string
|
||||
logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}")
|
||||
logger.warning(
|
||||
f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
|
||||
)
|
||||
pass
|
||||
|
||||
return records
|
||||
|
|
@ -586,8 +702,6 @@ class DatabaseConnector:
|
|||
logger.error(f"Error loading table {table}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
|
||||
def _registerInitialId(self, table: str, initialId: str) -> bool:
|
||||
"""Registers the initial ID for a table."""
|
||||
try:
|
||||
|
|
@ -602,13 +716,17 @@ class DatabaseConnector:
|
|||
else:
|
||||
# Check if the existing initial ID still exists in the table
|
||||
existingInitialId = systemData[table]
|
||||
records = self.getRecordset(model_class, recordFilter={"id": existingInitialId})
|
||||
records = self.getRecordset(
|
||||
model_class, recordFilter={"id": existingInitialId}
|
||||
)
|
||||
if not records:
|
||||
# The initial record no longer exists, update to the new one
|
||||
systemData[table] = initialId
|
||||
success = self._saveSystemTable(systemData)
|
||||
if success:
|
||||
logger.info(f"Initial ID updated from {existingInitialId} to {initialId} for table {table}")
|
||||
logger.info(
|
||||
f"Initial ID updated from {existingInitialId} to {initialId} for table {table}"
|
||||
)
|
||||
return success
|
||||
else:
|
||||
return True
|
||||
|
|
@ -625,7 +743,9 @@ class DatabaseConnector:
|
|||
del systemData[table]
|
||||
success = self._saveSystemTable(systemData)
|
||||
if success:
|
||||
logger.info(f"Initial ID for table {table} removed from system table")
|
||||
logger.info(
|
||||
f"Initial ID for table {table} removed from system table"
|
||||
)
|
||||
return success
|
||||
return True # If not present, this is not an error
|
||||
except Exception as e:
|
||||
|
|
@ -662,7 +782,7 @@ class DatabaseConnector:
|
|||
ORDER BY table_name
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
tables = [row['table_name'] for row in rows]
|
||||
tables = [row["table_name"] for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading the database {self.dbDatabase}: {e}")
|
||||
|
||||
|
|
@ -679,7 +799,9 @@ class DatabaseConnector:
|
|||
|
||||
return fields
|
||||
|
||||
def getSchema(self, model_class: type, language: str = None) -> Dict[str, Dict[str, Any]]:
|
||||
def getSchema(
|
||||
self, model_class: type, language: str = None
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""Returns a schema object for a table with data types and labels."""
|
||||
data = self._loadTable(model_class)
|
||||
|
||||
|
|
@ -694,14 +816,16 @@ class DatabaseConnector:
|
|||
dataType = type(value).__name__
|
||||
label = field
|
||||
|
||||
schema[field] = {
|
||||
"type": dataType,
|
||||
"label": label
|
||||
}
|
||||
schema[field] = {"type": dataType, "label": label}
|
||||
|
||||
return schema
|
||||
|
||||
def getRecordset(self, model_class: type, fieldFilter: List[str] = None, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
def getRecordset(
|
||||
self,
|
||||
model_class: type,
|
||||
fieldFilter: List[str] = None,
|
||||
recordFilter: Dict[str, Any] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Returns a list of records from a table, filtered by criteria."""
|
||||
table = model_class.__name__
|
||||
|
||||
|
|
@ -734,30 +858,43 @@ class DatabaseConnector:
|
|||
fields = _get_model_fields(model_class)
|
||||
for record in records:
|
||||
for field_name, field_type in fields.items():
|
||||
if field_type == 'JSONB' and field_name in record:
|
||||
if field_type == "JSONB" and field_name in record:
|
||||
if record[field_name] is None:
|
||||
# Convert None to appropriate default based on field name
|
||||
if field_name in ['logs', 'messages', 'tasks', 'expectedDocumentFormats', 'resultDocuments']:
|
||||
if field_name in [
|
||||
"logs",
|
||||
"messages",
|
||||
"tasks",
|
||||
"expectedDocumentFormats",
|
||||
"resultDocuments",
|
||||
]:
|
||||
record[field_name] = []
|
||||
elif field_name in ['execParameters', 'stats']:
|
||||
elif field_name in ["execParameters", "stats"]:
|
||||
record[field_name] = {}
|
||||
else:
|
||||
record[field_name] = None
|
||||
else:
|
||||
import json
|
||||
|
||||
try:
|
||||
if isinstance(record[field_name], str):
|
||||
# Parse JSON string back to Python object
|
||||
record[field_name] = json.loads(record[field_name])
|
||||
record[field_name] = json.loads(
|
||||
record[field_name]
|
||||
)
|
||||
elif isinstance(record[field_name], (dict, list)):
|
||||
# Already a Python object, keep as is
|
||||
pass
|
||||
else:
|
||||
# Try to parse as JSON
|
||||
record[field_name] = json.loads(str(record[field_name]))
|
||||
record[field_name] = json.loads(
|
||||
str(record[field_name])
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
# If parsing fails, keep as string
|
||||
logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}")
|
||||
logger.warning(
|
||||
f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}"
|
||||
)
|
||||
pass
|
||||
|
||||
# If fieldFilter is available, reduce the fields
|
||||
|
|
@ -776,7 +913,9 @@ class DatabaseConnector:
|
|||
logger.error(f"Error loading records from table {table}: {e}")
|
||||
return []
|
||||
|
||||
def recordCreate(self, model_class: type, record: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
|
||||
def recordCreate(
|
||||
self, model_class: type, record: Union[Dict[str, Any], BaseModel]
|
||||
) -> Dict[str, Any]:
|
||||
"""Creates a new record in a table based on Pydantic model class."""
|
||||
# If record is a Pydantic model, convert to dict
|
||||
if isinstance(record, BaseModel):
|
||||
|
|
@ -803,7 +942,9 @@ class DatabaseConnector:
|
|||
|
||||
return record
|
||||
|
||||
def recordModify(self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
|
||||
def recordModify(
|
||||
self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel]
|
||||
) -> Dict[str, Any]:
|
||||
"""Modifies an existing record in a table based on Pydantic model class."""
|
||||
# Load existing record
|
||||
existingRecord = self._loadRecord(model_class, recordId)
|
||||
|
|
@ -821,8 +962,12 @@ class DatabaseConnector:
|
|||
|
||||
# CRITICAL: Ensure we never modify the ID
|
||||
if "id" in record and str(record["id"]) != recordId:
|
||||
logger.error(f"Attempted to modify record ID from {recordId} to {record['id']}")
|
||||
raise ValueError("Cannot modify record ID - it must match the provided recordId")
|
||||
logger.error(
|
||||
f"Attempted to modify record ID from {recordId} to {record['id']}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Cannot modify record ID - it must match the provided recordId"
|
||||
)
|
||||
|
||||
# Update existing record with new data
|
||||
existingRecord.update(record)
|
||||
|
|
@ -841,7 +986,9 @@ class DatabaseConnector:
|
|||
|
||||
with self.connection.cursor() as cursor:
|
||||
# Check if record exists
|
||||
cursor.execute(f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,))
|
||||
cursor.execute(
|
||||
f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
|
||||
)
|
||||
if not cursor.fetchone():
|
||||
return False
|
||||
|
||||
|
|
@ -849,7 +996,9 @@ class DatabaseConnector:
|
|||
initialId = self.getInitialId(model_class)
|
||||
if initialId is not None and initialId == recordId:
|
||||
self._removeInitialId(table)
|
||||
logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table")
|
||||
logger.info(
|
||||
f"Initial ID {recordId} for table {table} has been removed from the system table"
|
||||
)
|
||||
|
||||
# Delete the record
|
||||
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
|
||||
|
|
@ -864,7 +1013,6 @@ class DatabaseConnector:
|
|||
self.connection.rollback()
|
||||
return False
|
||||
|
||||
|
||||
def getInitialId(self, model_class: type) -> Optional[str]:
|
||||
"""Returns the initial ID for a table."""
|
||||
table = model_class.__name__
|
||||
|
|
@ -874,7 +1022,11 @@ class DatabaseConnector:
|
|||
|
||||
def close(self):
|
||||
"""Close the database connection."""
|
||||
if hasattr(self, 'connection') and self.connection and not self.connection.closed:
|
||||
if (
|
||||
hasattr(self, "connection")
|
||||
and self.connection
|
||||
and not self.connection.closed
|
||||
):
|
||||
self.connection.close()
|
||||
|
||||
def __del__(self):
|
||||
|
|
|
|||
|
|
@ -1,21 +1,33 @@
|
|||
"""Chat models: ChatWorkflow, ChatMessage, ChatLog, ChatStat, ChatDocument."""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
import uuid
|
||||
|
||||
|
||||
class ChatStat(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)")
|
||||
messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)")
|
||||
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
workflowId: Optional[str] = Field(
|
||||
None, description="Foreign key to workflow (for workflow stats)"
|
||||
)
|
||||
messageId: Optional[str] = Field(
|
||||
None, description="Foreign key to message (for message stats)"
|
||||
)
|
||||
processingTime: Optional[float] = Field(
|
||||
None, description="Processing time in seconds"
|
||||
)
|
||||
tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
|
||||
bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
|
||||
bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
|
||||
successRate: Optional[float] = Field(None, description="Success rate of operations")
|
||||
errorCount: Optional[int] = Field(None, description="Number of errors encountered")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatStat",
|
||||
{"en": "Chat Statistics", "fr": "Statistiques de chat"},
|
||||
|
|
@ -32,15 +44,27 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatLog(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
workflowId: str = Field(description="Foreign key to workflow")
|
||||
message: str = Field(description="Log message")
|
||||
type: str = Field(description="Log type (info, warning, error, etc.)")
|
||||
timestamp: float = Field(default_factory=get_utc_timestamp, description="When the log entry was created (UTC timestamp in seconds)")
|
||||
timestamp: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="When the log entry was created (UTC timestamp in seconds)",
|
||||
)
|
||||
status: Optional[str] = Field(None, description="Status of the log entry")
|
||||
progress: Optional[float] = Field(None, description="Progress indicator (0.0 to 1.0)")
|
||||
performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics")
|
||||
progress: Optional[float] = Field(
|
||||
None, description="Progress indicator (0.0 to 1.0)"
|
||||
)
|
||||
performance: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Performance metrics"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatLog",
|
||||
{"en": "Chat Log", "fr": "Journal de chat"},
|
||||
|
|
@ -56,8 +80,11 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatDocument(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
messageId: str = Field(description="Foreign key to message")
|
||||
fileId: str = Field(description="Foreign key to file")
|
||||
fileName: str = Field(description="Name of the file")
|
||||
|
|
@ -66,7 +93,11 @@ class ChatDocument(BaseModel, ModelMixin):
|
|||
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
||||
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
||||
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
||||
actionId: Optional[str] = Field(None, description="ID of the action that created this document")
|
||||
actionId: Optional[str] = Field(
|
||||
None, description="ID of the action that created this document"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatDocument",
|
||||
{"en": "Chat Document", "fr": "Document de chat"},
|
||||
|
|
@ -84,17 +115,26 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ContentMetadata(BaseModel, ModelMixin):
|
||||
size: int = Field(description="Content size in bytes")
|
||||
pages: Optional[int] = Field(None, description="Number of pages for multi-page content")
|
||||
pages: Optional[int] = Field(
|
||||
None, description="Number of pages for multi-page content"
|
||||
)
|
||||
error: Optional[str] = Field(None, description="Processing error if any")
|
||||
width: Optional[int] = Field(None, description="Width in pixels for images/videos")
|
||||
height: Optional[int] = Field(None, description="Height in pixels for images/videos")
|
||||
height: Optional[int] = Field(
|
||||
None, description="Height in pixels for images/videos"
|
||||
)
|
||||
colorMode: Optional[str] = Field(None, description="Color mode")
|
||||
fps: Optional[float] = Field(None, description="Frames per second for videos")
|
||||
durationSec: Optional[float] = Field(None, description="Duration in seconds for media")
|
||||
durationSec: Optional[float] = Field(
|
||||
None, description="Duration in seconds for media"
|
||||
)
|
||||
mimeType: str = Field(description="MIME type of the content")
|
||||
base64Encoded: bool = Field(description="Whether the data is base64 encoded")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ContentMetadata",
|
||||
{"en": "Content Metadata", "fr": "Métadonnées du contenu"},
|
||||
|
|
@ -112,10 +152,13 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ContentItem(BaseModel, ModelMixin):
|
||||
label: str = Field(description="Content label")
|
||||
data: str = Field(description="Extracted text content")
|
||||
metadata: ContentMetadata = Field(description="Content metadata")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ContentItem",
|
||||
{"en": "Content Item", "fr": "Élément de contenu"},
|
||||
|
|
@ -126,9 +169,14 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatContentExtracted(BaseModel, ModelMixin):
|
||||
id: str = Field(description="Reference to source ChatDocument")
|
||||
contents: List[ContentItem] = Field(default_factory=list, description="List of content items")
|
||||
contents: List[ContentItem] = Field(
|
||||
default_factory=list, description="List of content items"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatContentExtracted",
|
||||
{"en": "Extracted Content", "fr": "Contenu extrait"},
|
||||
|
|
@ -138,28 +186,58 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
workflowId: str = Field(description="Foreign key to workflow")
|
||||
parentMessageId: Optional[str] = Field(None, description="Parent message ID for threading")
|
||||
documents: List[ChatDocument] = Field(default_factory=list, description="Associated documents")
|
||||
documentsLabel: Optional[str] = Field(None, description="Label for the set of documents")
|
||||
parentMessageId: Optional[str] = Field(
|
||||
None, description="Parent message ID for threading"
|
||||
)
|
||||
documents: List[ChatDocument] = Field(
|
||||
default_factory=list, description="Associated documents"
|
||||
)
|
||||
documentsLabel: Optional[str] = Field(
|
||||
None, description="Label for the set of documents"
|
||||
)
|
||||
message: Optional[str] = Field(None, description="Message content")
|
||||
summary: Optional[str] = Field(None, description="Short summary of this message for planning/history")
|
||||
summary: Optional[str] = Field(
|
||||
None, description="Short summary of this message for planning/history"
|
||||
)
|
||||
role: str = Field(description="Role of the message sender")
|
||||
status: str = Field(description="Status of the message (first, step, last)")
|
||||
sequenceNr: int = Field(description="Sequence number of the message (set automatically)")
|
||||
publishedAt: float = Field(default_factory=get_utc_timestamp, description="When the message was published (UTC timestamp in seconds)")
|
||||
sequenceNr: int = Field(
|
||||
description="Sequence number of the message (set automatically)"
|
||||
)
|
||||
publishedAt: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="When the message was published (UTC timestamp in seconds)",
|
||||
)
|
||||
stats: Optional[ChatStat] = Field(None, description="Statistics for this message")
|
||||
success: Optional[bool] = Field(None, description="Whether the message processing was successful")
|
||||
actionId: Optional[str] = Field(None, description="ID of the action that produced this message")
|
||||
actionMethod: Optional[str] = Field(None, description="Method of the action that produced this message")
|
||||
actionName: Optional[str] = Field(None, description="Name of the action that produced this message")
|
||||
success: Optional[bool] = Field(
|
||||
None, description="Whether the message processing was successful"
|
||||
)
|
||||
actionId: Optional[str] = Field(
|
||||
None, description="ID of the action that produced this message"
|
||||
)
|
||||
actionMethod: Optional[str] = Field(
|
||||
None, description="Method of the action that produced this message"
|
||||
)
|
||||
actionName: Optional[str] = Field(
|
||||
None, description="Name of the action that produced this message"
|
||||
)
|
||||
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
||||
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
||||
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
||||
taskProgress: Optional[str] = Field(None, description="Task progress status: pending, running, success, fail, retry")
|
||||
actionProgress: Optional[str] = Field(None, description="Action progress status: pending, running, success, fail")
|
||||
taskProgress: Optional[str] = Field(
|
||||
None, description="Task progress status: pending, running, success, fail, retry"
|
||||
)
|
||||
actionProgress: Optional[str] = Field(
|
||||
None, description="Action progress status: pending, running, success, fail"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessage",
|
||||
{"en": "Chat Message", "fr": "Message de chat"},
|
||||
|
|
@ -188,32 +266,139 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatWorkflow(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
mandateId: str = Field(description="ID of the mandate this workflow belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
status: str = Field(description="Current status of the workflow", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Primary key",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
mandateId: str = Field(
|
||||
description="ID of the mandate this workflow belongs to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
status: str = Field(
|
||||
description="Current status of the workflow",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False,
|
||||
frontend_options=[
|
||||
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
|
||||
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
|
||||
{"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
|
||||
{"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
|
||||
])
|
||||
name: Optional[str] = Field(None, description="Name of the workflow", frontend_type="text", frontend_readonly=False, frontend_required=True)
|
||||
currentRound: int = Field(description="Current round number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
currentTask: int = Field(default=0, description="Current task number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
currentAction: int = Field(default=0, description="Current action number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
totalTasks: int = Field(default=0, description="Total number of tasks in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
totalActions: int = Field(default=0, description="Total number of actions in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
lastActivity: float = Field(default_factory=get_utc_timestamp, description="Timestamp of last activity (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
startedAt: float = Field(default_factory=get_utc_timestamp, description="When the workflow started (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
stats: Optional[ChatStat] = Field(None, description="Workflow statistics", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
tasks: list = Field(default_factory=list, description="List of tasks in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
workflowMode: str = Field(default="Actionplan", description="Workflow mode selector", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
|
||||
{"value": "Actionplan", "label": {"en": "Action Plan", "fr": "Plan d'actions"}},
|
||||
],
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of the workflow",
|
||||
frontend_type="text",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True,
|
||||
)
|
||||
currentRound: int = Field(
|
||||
description="Current round number",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
currentTask: int = Field(
|
||||
default=0,
|
||||
description="Current task number",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
currentAction: int = Field(
|
||||
default=0,
|
||||
description="Current action number",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
totalTasks: int = Field(
|
||||
default=0,
|
||||
description="Total number of tasks in the workflow",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
totalActions: int = Field(
|
||||
default=0,
|
||||
description="Total number of actions in the workflow",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
lastActivity: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="Timestamp of last activity (UTC timestamp in seconds)",
|
||||
frontend_type="timestamp",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
startedAt: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="When the workflow started (UTC timestamp in seconds)",
|
||||
frontend_type="timestamp",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
logs: List[ChatLog] = Field(
|
||||
default_factory=list,
|
||||
description="Workflow logs",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
messages: List[ChatMessage] = Field(
|
||||
default_factory=list,
|
||||
description="Messages in the workflow",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
stats: Optional[ChatStat] = Field(
|
||||
None,
|
||||
description="Workflow statistics",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
tasks: list = Field(
|
||||
default_factory=list,
|
||||
description="List of tasks in the workflow",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
workflowMode: str = Field(
|
||||
default="Actionplan",
|
||||
description="Workflow mode selector",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False,
|
||||
frontend_options=[
|
||||
{
|
||||
"value": "Actionplan",
|
||||
"label": {"en": "Action Plan", "fr": "Plan d'actions"},
|
||||
},
|
||||
{"value": "React", "label": {"en": "React", "fr": "Réactif"}},
|
||||
])
|
||||
maxSteps: int = Field(default=5, description="Maximum number of iterations in react mode", frontend_type="integer", frontend_readonly=False, frontend_required=False)
|
||||
],
|
||||
)
|
||||
maxSteps: int = Field(
|
||||
default=5,
|
||||
description="Maximum number of iterations in react mode",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False,
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ChatWorkflow",
|
||||
{"en": "Chat Workflow", "fr": "Flux de travail de chat"},
|
||||
|
|
@ -238,12 +423,16 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class UserInputRequest(BaseModel, ModelMixin):
|
||||
prompt: str = Field(description="Prompt for the user")
|
||||
listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
|
||||
userLanguage: str = Field(default="en", description="User's preferred language")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"UserInputRequest", {"en": "User Input Request", "fr": "Demande de saisie utilisateur"},
|
||||
"UserInputRequest",
|
||||
{"en": "User Input Request", "fr": "Demande de saisie utilisateur"},
|
||||
{
|
||||
"prompt": {"en": "Prompt", "fr": "Invite"},
|
||||
"listFileId": {"en": "File IDs", "fr": "IDs des fichiers"},
|
||||
|
|
@ -251,11 +440,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionDocument(BaseModel, ModelMixin):
|
||||
"""Clear document structure for action results"""
|
||||
|
||||
documentName: str = Field(description="Name of the document")
|
||||
documentData: Any = Field(description="Content/data of the document")
|
||||
mimeType: str = Field(description="MIME type of the document")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionDocument",
|
||||
{"en": "Action Document", "fr": "Document d'action"},
|
||||
|
|
@ -266,6 +459,7 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionResult(BaseModel, ModelMixin):
|
||||
"""Clean action result with documents as primary output
|
||||
|
||||
|
|
@ -276,16 +470,25 @@ class ActionResult(BaseModel, ModelMixin):
|
|||
|
||||
success: bool = Field(description="Whether execution succeeded")
|
||||
error: Optional[str] = Field(None, description="Error message if failed")
|
||||
documents: List[ActionDocument] = Field(default_factory=list, description="Document outputs")
|
||||
resultLabel: Optional[str] = Field(None, description="Label for document routing (set by action handler, not by action methods)")
|
||||
documents: List[ActionDocument] = Field(
|
||||
default_factory=list, description="Document outputs"
|
||||
)
|
||||
resultLabel: Optional[str] = Field(
|
||||
None,
|
||||
description="Label for document routing (set by action handler, not by action methods)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def isSuccess(cls, documents: List[ActionDocument] = None) -> "ActionResult":
|
||||
return cls(success=True, documents=documents or [])
|
||||
|
||||
@classmethod
|
||||
def isFailure(cls, error: str, documents: List[ActionDocument] = None) -> "ActionResult":
|
||||
def isFailure(
|
||||
cls, error: str, documents: List[ActionDocument] = None
|
||||
) -> "ActionResult":
|
||||
return cls(success=False, documents=documents or [], error=error)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionResult",
|
||||
{"en": "Action Result", "fr": "Résultat de l'action"},
|
||||
|
|
@ -297,9 +500,14 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionSelection(BaseModel, ModelMixin):
|
||||
method: str = Field(description="Method to execute (e.g., web, document, ai)")
|
||||
name: str = Field(description="Action name within the method (e.g., search, extract)")
|
||||
name: str = Field(
|
||||
description="Action name within the method (e.g., search, extract)"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionSelection",
|
||||
{"en": "Action Selection", "fr": "Sélection d'action"},
|
||||
|
|
@ -309,8 +517,13 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionParameters(BaseModel, ModelMixin):
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to execute the selected action")
|
||||
parameters: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Parameters to execute the selected action"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionParameters",
|
||||
{"en": "Action Parameters", "fr": "Paramètres d'action"},
|
||||
|
|
@ -319,10 +532,13 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ObservationPreview(BaseModel, ModelMixin):
|
||||
name: str = Field(description="Document name or URL label")
|
||||
mime: str = Field(description="MIME type or kind")
|
||||
snippet: str = Field(description="Short snippet or summary")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ObservationPreview",
|
||||
{"en": "Observation Preview", "fr": "Aperçu d'observation"},
|
||||
|
|
@ -333,12 +549,19 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class Observation(BaseModel, ModelMixin):
|
||||
success: bool = Field(description="Action execution success flag")
|
||||
resultLabel: str = Field(description="Deterministic label for produced documents")
|
||||
documentsCount: int = Field(description="Number of produced documents")
|
||||
previews: List[ObservationPreview] = Field(default_factory=list, description="Compact previews of outputs")
|
||||
notes: List[str] = Field(default_factory=list, description="Short notes or key facts")
|
||||
previews: List[ObservationPreview] = Field(
|
||||
default_factory=list, description="Compact previews of outputs"
|
||||
)
|
||||
notes: List[str] = Field(
|
||||
default_factory=list, description="Short notes or key facts"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"Observation",
|
||||
{"en": "Observation", "fr": "Observation"},
|
||||
|
|
@ -351,12 +574,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
class TaskStatus(str):
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskStatus",
|
||||
{"en": "Task Status", "fr": "Statut de la tâche"},
|
||||
|
|
@ -369,9 +595,14 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class DocumentExchange(BaseModel, ModelMixin):
|
||||
documentsLabel: str = Field(description="Label for the set of documents")
|
||||
documents: List[str] = Field(default_factory=list, description="List of document references")
|
||||
documents: List[str] = Field(
|
||||
default_factory=list, description="List of document references"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"DocumentExchange",
|
||||
{"en": "Document Exchange", "fr": "Échange de documents"},
|
||||
|
|
@ -381,20 +612,33 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ActionItem(BaseModel, ModelMixin):
|
||||
id: str = Field(..., description="Action ID")
|
||||
execMethod: str = Field(..., description="Method to execute")
|
||||
execAction: str = Field(..., description="Action to perform")
|
||||
execParameters: Dict[str, Any] = Field(default_factory=dict, description="Action parameters")
|
||||
execResultLabel: Optional[str] = Field(None, description="Label for the set of result documents")
|
||||
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(None, description="Expected document formats (optional)")
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
execParameters: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Action parameters"
|
||||
)
|
||||
execResultLabel: Optional[str] = Field(
|
||||
None, description="Label for the set of result documents"
|
||||
)
|
||||
expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(
|
||||
None, description="Expected document formats (optional)"
|
||||
)
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="User-friendly message in user's language"
|
||||
)
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Action status")
|
||||
error: Optional[str] = Field(None, description="Error message if action failed")
|
||||
retryCount: int = Field(default=0, description="Number of retries attempted")
|
||||
retryMax: int = Field(default=3, description="Maximum number of retries")
|
||||
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
|
||||
timestamp: float = Field(..., description="When the action was executed (UTC timestamp in seconds)")
|
||||
processingTime: Optional[float] = Field(
|
||||
None, description="Processing time in seconds"
|
||||
)
|
||||
timestamp: float = Field(
|
||||
..., description="When the action was executed (UTC timestamp in seconds)"
|
||||
)
|
||||
result: Optional[str] = Field(None, description="Result of the action")
|
||||
|
||||
def setSuccess(self, result: str = None) -> None:
|
||||
|
|
@ -408,6 +652,8 @@ class ActionItem(BaseModel, ModelMixin):
|
|||
"""Set the action as failed with error message"""
|
||||
self.status = TaskStatus.FAILED
|
||||
self.error = error_message
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ActionItem",
|
||||
{"en": "Task Action", "fr": "Action de tâche"},
|
||||
|
|
@ -417,7 +663,10 @@ register_model_labels(
|
|||
"execAction": {"en": "Action", "fr": "Action"},
|
||||
"execParameters": {"en": "Parameters", "fr": "Paramètres"},
|
||||
"execResultLabel": {"en": "Result Label", "fr": "Label du résultat"},
|
||||
"expectedDocumentFormats": {"en": "Expected Document Formats", "fr": "Formats de documents attendus"},
|
||||
"expectedDocumentFormats": {
|
||||
"en": "Expected Document Formats",
|
||||
"fr": "Formats de documents attendus",
|
||||
},
|
||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||
"status": {"en": "Status", "fr": "Statut"},
|
||||
"error": {"en": "Error", "fr": "Erreur"},
|
||||
|
|
@ -429,12 +678,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskResult(BaseModel, ModelMixin):
|
||||
taskId: str = Field(..., description="Task ID")
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
|
||||
success: bool = Field(..., description="Whether the task was successful")
|
||||
feedback: Optional[str] = Field(None, description="Task feedback message")
|
||||
error: Optional[str] = Field(None, description="Error message if task failed")
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskResult",
|
||||
{"en": "Task Result", "fr": "Résultat de tâche"},
|
||||
|
|
@ -447,22 +699,39 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskItem(BaseModel, ModelMixin):
|
||||
id: str = Field(..., description="Task ID")
|
||||
workflowId: str = Field(..., description="Workflow ID")
|
||||
userInput: str = Field(..., description="User input that triggered the task")
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
|
||||
error: Optional[str] = Field(None, description="Error message if task failed")
|
||||
startedAt: Optional[float] = Field(None, description="When the task started (UTC timestamp in seconds)")
|
||||
finishedAt: Optional[float] = Field(None, description="When the task finished (UTC timestamp in seconds)")
|
||||
actionList: List[ActionItem] = Field(default_factory=list, description="List of actions to execute")
|
||||
startedAt: Optional[float] = Field(
|
||||
None, description="When the task started (UTC timestamp in seconds)"
|
||||
)
|
||||
finishedAt: Optional[float] = Field(
|
||||
None, description="When the task finished (UTC timestamp in seconds)"
|
||||
)
|
||||
actionList: List[ActionItem] = Field(
|
||||
default_factory=list, description="List of actions to execute"
|
||||
)
|
||||
retryCount: int = Field(default=0, description="Number of retries attempted")
|
||||
retryMax: int = Field(default=3, description="Maximum number of retries")
|
||||
rollbackOnFailure: bool = Field(default=True, description="Whether to rollback on failure")
|
||||
dependencies: List[str] = Field(default_factory=list, description="List of task IDs this task depends on")
|
||||
rollbackOnFailure: bool = Field(
|
||||
default=True, description="Whether to rollback on failure"
|
||||
)
|
||||
dependencies: List[str] = Field(
|
||||
default_factory=list, description="List of task IDs this task depends on"
|
||||
)
|
||||
feedback: Optional[str] = Field(None, description="Task feedback message")
|
||||
processingTime: Optional[float] = Field(None, description="Total processing time in seconds")
|
||||
resultLabels: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Map of result labels to their values")
|
||||
processingTime: Optional[float] = Field(
|
||||
None, description="Total processing time in seconds"
|
||||
)
|
||||
resultLabels: Optional[Dict[str, Any]] = Field(
|
||||
default_factory=dict, description="Map of result labels to their values"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskItem",
|
||||
{"en": "Task", "fr": "Tâche"},
|
||||
|
|
@ -481,13 +750,18 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskStep(BaseModel, ModelMixin):
|
||||
id: str
|
||||
objective: str
|
||||
dependencies: Optional[list[str]] = Field(default_factory=list)
|
||||
success_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
estimated_complexity: Optional[str] = None
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="User-friendly message in user's language"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskStep",
|
||||
{"en": "Task Step", "fr": "Étape de tâche"},
|
||||
|
|
@ -496,23 +770,45 @@ register_model_labels(
|
|||
"objective": {"en": "Objective", "fr": "Objectif"},
|
||||
"dependencies": {"en": "Dependencies", "fr": "Dépendances"},
|
||||
"success_criteria": {"en": "Success Criteria", "fr": "Critères de succès"},
|
||||
"estimated_complexity": {"en": "Estimated Complexity", "fr": "Complexité estimée"},
|
||||
"estimated_complexity": {
|
||||
"en": "Estimated Complexity",
|
||||
"fr": "Complexité estimée",
|
||||
},
|
||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskHandover(BaseModel, ModelMixin):
|
||||
taskId: str = Field(description="Target task ID")
|
||||
sourceTask: Optional[str] = Field(None, description="Source task ID")
|
||||
inputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Available input documents")
|
||||
outputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Produced output documents")
|
||||
inputDocuments: List[DocumentExchange] = Field(
|
||||
default_factory=list, description="Available input documents"
|
||||
)
|
||||
outputDocuments: List[DocumentExchange] = Field(
|
||||
default_factory=list, description="Produced output documents"
|
||||
)
|
||||
context: Dict[str, Any] = Field(default_factory=dict, description="Task context")
|
||||
previousResults: List[str] = Field(default_factory=list, description="Previous result summaries")
|
||||
improvements: List[str] = Field(default_factory=list, description="Improvement suggestions")
|
||||
workflowSummary: Optional[str] = Field(None, description="Summarized workflow context")
|
||||
messageHistory: List[str] = Field(default_factory=list, description="Key message summaries")
|
||||
timestamp: float = Field(..., description="When the handover was created (UTC timestamp in seconds)")
|
||||
handoverType: str = Field(default="task", description="Type of handover: task, phase, or workflow")
|
||||
previousResults: List[str] = Field(
|
||||
default_factory=list, description="Previous result summaries"
|
||||
)
|
||||
improvements: List[str] = Field(
|
||||
default_factory=list, description="Improvement suggestions"
|
||||
)
|
||||
workflowSummary: Optional[str] = Field(
|
||||
None, description="Summarized workflow context"
|
||||
)
|
||||
messageHistory: List[str] = Field(
|
||||
default_factory=list, description="Key message summaries"
|
||||
)
|
||||
timestamp: float = Field(
|
||||
..., description="When the handover was created (UTC timestamp in seconds)"
|
||||
)
|
||||
handoverType: str = Field(
|
||||
default="task", description="Type of handover: task, phase, or workflow"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskHandover",
|
||||
{"en": "Task Handover", "fr": "Transfert de tâche"},
|
||||
|
|
@ -531,9 +827,10 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskContext(BaseModel, ModelMixin):
|
||||
task_step: TaskStep
|
||||
workflow: Optional['ChatWorkflow'] = None
|
||||
workflow: Optional["ChatWorkflow"] = None
|
||||
workflow_id: Optional[str] = None
|
||||
available_documents: Optional[str] = "No documents available"
|
||||
available_connections: Optional[list[str]] = Field(default_factory=list)
|
||||
|
|
@ -562,6 +859,7 @@ class TaskContext(BaseModel, ModelMixin):
|
|||
self.improvements = []
|
||||
self.improvements.append(improvement)
|
||||
|
||||
|
||||
class ReviewContext(BaseModel, ModelMixin):
|
||||
task_step: TaskStep
|
||||
task_actions: Optional[list] = Field(default_factory=list)
|
||||
|
|
@ -570,6 +868,7 @@ class ReviewContext(BaseModel, ModelMixin):
|
|||
workflow_id: Optional[str] = None
|
||||
previous_results: Optional[list[str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ReviewResult(BaseModel, ModelMixin):
|
||||
status: str
|
||||
reason: Optional[str] = None
|
||||
|
|
@ -579,7 +878,11 @@ class ReviewResult(BaseModel, ModelMixin):
|
|||
met_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
unmet_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
confidence: Optional[float] = 0.5
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="User-friendly message in user's language"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"ReviewResult",
|
||||
{"en": "Review Result", "fr": "Résultat de l'évaluation"},
|
||||
|
|
@ -596,10 +899,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class TaskPlan(BaseModel, ModelMixin):
|
||||
overview: str
|
||||
tasks: list[TaskStep]
|
||||
userMessage: Optional[str] = Field(None, description="Overall user-friendly message for the task plan")
|
||||
userMessage: Optional[str] = Field(
|
||||
None, description="Overall user-friendly message for the task plan"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"TaskPlan",
|
||||
{"en": "Task Plan", "fr": "Plan de tâches"},
|
||||
|
|
@ -613,10 +921,16 @@ register_model_labels(
|
|||
# Resolve forward references
|
||||
TaskContext.update_forward_refs()
|
||||
|
||||
|
||||
class PromptPlaceholder(BaseModel, ModelMixin):
|
||||
label: str
|
||||
content: str
|
||||
summaryAllowed: bool = Field(default=False, description="Whether host may summarize content before sending to AI")
|
||||
summaryAllowed: bool = Field(
|
||||
default=False,
|
||||
description="Whether host may summarize content before sending to AI",
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"PromptPlaceholder",
|
||||
{"en": "Prompt Placeholder", "fr": "Espace réservé d'invite"},
|
||||
|
|
@ -627,11 +941,15 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class PromptBundle(BaseModel, ModelMixin):
|
||||
prompt: str
|
||||
placeholders: List[PromptPlaceholder] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"PromptBundle", {"en": "Prompt Bundle", "fr": "Lot d'invite"},
|
||||
"PromptBundle",
|
||||
{"en": "Prompt Bundle", "fr": "Lot d'invite"},
|
||||
{
|
||||
"prompt": {"en": "Prompt", "fr": "Invite"},
|
||||
"placeholders": {"en": "Placeholders", "fr": "Espaces réservés"},
|
||||
|
|
|
|||
216
modules/datamodels/datamodelChatbot.py
Normal file
216
modules/datamodels/datamodelChatbot.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""Chatbot API models for request/response handling."""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
|
||||
|
||||
# Chatbot API Models
|
||||
class MessageItem(BaseModel, ModelMixin):
|
||||
"""Individual message in a thread"""
|
||||
|
||||
role: str = Field(..., description="Message role (user or assistant)")
|
||||
content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel, ModelMixin):
|
||||
"""Request model for posting a chat message"""
|
||||
|
||||
thread_id: Optional[str] = Field(
|
||||
None, description="Thread ID (creates new thread if not provided)"
|
||||
)
|
||||
message: str = Field(..., description="User message content")
|
||||
tools: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="List of tool IDs to use. If not provided, all user's tools will be used",
|
||||
)
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel, ModelMixin):
|
||||
"""Response model for posting a chat message"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
messages: List[MessageItem] = Field(..., description="All messages in thread")
|
||||
|
||||
|
||||
class ThreadSummary(BaseModel, ModelMixin):
|
||||
"""Summary of a chat thread for list view"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
thread_name: str = Field(..., description="Thread name")
|
||||
date_created: float = Field(..., description="Thread creation timestamp")
|
||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
||||
|
||||
|
||||
class ThreadListResponse(BaseModel, ModelMixin):
|
||||
"""Response model for listing all threads"""
|
||||
|
||||
threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
|
||||
|
||||
|
||||
class ThreadDetail(BaseModel, ModelMixin):
|
||||
"""Detailed view of a single thread"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
date_created: float = Field(..., description="Thread creation timestamp")
|
||||
date_updated: float = Field(..., description="Thread last updated timestamp")
|
||||
messages: List[MessageItem] = Field(
|
||||
..., description="All messages in chronological order"
|
||||
)
|
||||
|
||||
|
||||
class RenameThreadRequest(BaseModel, ModelMixin):
|
||||
"""Request model for renaming a thread"""
|
||||
|
||||
new_name: str = Field(..., description="New name for the thread")
|
||||
|
||||
|
||||
class DeleteResponse(BaseModel, ModelMixin):
|
||||
"""Response model for delete operations"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
thread_id: str = Field(..., description="Deleted thread ID")
|
||||
|
||||
|
||||
# Tool Management Models
|
||||
class ToolInfo(BaseModel, ModelMixin):
|
||||
"""Information about a chatbot tool"""
|
||||
|
||||
id: str = Field(..., description="Tool UUID")
|
||||
tool_id: str = Field(
|
||||
..., description="Tool identifier (e.g., 'shared.tavily_search')"
|
||||
)
|
||||
name: str = Field(..., description="Tool function name")
|
||||
label: str = Field(..., description="Display label for the tool")
|
||||
category: str = Field(..., description="Tool category (shared or customer)")
|
||||
description: str = Field(..., description="Tool description")
|
||||
is_active: bool = Field(..., description="Whether the tool is active")
|
||||
date_created: float = Field(..., description="Creation timestamp")
|
||||
date_updated: float = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class ToolListResponse(BaseModel, ModelMixin):
|
||||
"""Response model for listing all tools"""
|
||||
|
||||
tools: List[ToolInfo] = Field(..., description="List of available tools")
|
||||
|
||||
|
||||
class GrantToolRequest(BaseModel, ModelMixin):
|
||||
"""Request model for granting a tool to a user"""
|
||||
|
||||
user_id: str = Field(..., description="User ID to grant the tool to")
|
||||
tool_id: str = Field(..., description="Tool UUID from tools table")
|
||||
|
||||
|
||||
class GrantToolResponse(BaseModel, ModelMixin):
|
||||
"""Response model after granting a tool"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
tool_id: str = Field(..., description="Tool UUID")
|
||||
|
||||
|
||||
class RevokeToolRequest(BaseModel, ModelMixin):
|
||||
"""Request model for revoking a tool from a user"""
|
||||
|
||||
user_id: str = Field(..., description="User ID to revoke the tool from")
|
||||
tool_id: str = Field(..., description="Tool UUID from tools table")
|
||||
|
||||
|
||||
class RevokeToolResponse(BaseModel, ModelMixin):
|
||||
"""Response model after revoking a tool"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
tool_id: str = Field(..., description="Tool UUID")
|
||||
|
||||
|
||||
class UpdateToolRequest(BaseModel, ModelMixin):
|
||||
"""Request model for updating a tool's label and description"""
|
||||
|
||||
label: Optional[str] = Field(None, description="New label for the tool")
|
||||
description: Optional[str] = Field(None, description="New description for the tool")
|
||||
|
||||
|
||||
class UpdateToolResponse(BaseModel, ModelMixin):
|
||||
"""Response model after updating a tool"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
tool_id: str = Field(..., description="Tool UUID")
|
||||
updated_fields: List[str] = Field(..., description="List of updated field names")
|
||||
|
||||
|
||||
# Register model labels for internationalization
|
||||
register_model_labels(
|
||||
"MessageItem",
|
||||
{"en": "Message Item", "fr": "Élément de message"},
|
||||
{
|
||||
"role": {"en": "Role", "fr": "Rôle"},
|
||||
"content": {"en": "Content", "fr": "Contenu"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessageRequest",
|
||||
{"en": "Chat Message Request", "fr": "Demande de message de chat"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"message": {"en": "Message", "fr": "Message"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessageResponse",
|
||||
{"en": "Chat Message Response", "fr": "Réponse du message de chat"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"messages": {"en": "Messages", "fr": "Messages"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadSummary",
|
||||
{"en": "Thread Summary", "fr": "Résumé du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"thread_name": {"en": "Thread Name", "fr": "Nom du fil"},
|
||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadListResponse",
|
||||
{"en": "Thread List Response", "fr": "Réponse de liste de fils"},
|
||||
{
|
||||
"threads": {"en": "Threads", "fr": "Fils"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadDetail",
|
||||
{"en": "Thread Detail", "fr": "Détail du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"date_created": {"en": "Date Created", "fr": "Date de création"},
|
||||
"date_updated": {"en": "Date Updated", "fr": "Date de mise à jour"},
|
||||
"messages": {"en": "Messages", "fr": "Messages"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"RenameThreadRequest",
|
||||
{"en": "Rename Thread Request", "fr": "Demande de renommage de fil"},
|
||||
{
|
||||
"new_name": {"en": "New Name", "fr": "Nouveau nom"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"DeleteResponse",
|
||||
{"en": "Delete Response", "fr": "Réponse de suppression"},
|
||||
{
|
||||
"message": {"en": "Message", "fr": "Message"},
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
},
|
||||
)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""Security models: Token and AuthEvent."""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from .datamodelUam import AuthAuthority
|
||||
|
|
@ -13,25 +13,44 @@ class TokenStatus(str, Enum):
|
|||
ACTIVE = "active"
|
||||
REVOKED = "revoked"
|
||||
|
||||
|
||||
class Token(BaseModel, ModelMixin):
|
||||
id: Optional[str] = None
|
||||
userId: str
|
||||
authority: AuthAuthority
|
||||
connectionId: Optional[str] = Field(None, description="ID of the connection this token belongs to")
|
||||
connectionId: Optional[str] = Field(
|
||||
None, description="ID of the connection this token belongs to"
|
||||
)
|
||||
tokenAccess: str
|
||||
tokenType: str = "bearer"
|
||||
expiresAt: float = Field(description="When the token expires (UTC timestamp in seconds)")
|
||||
expiresAt: float = Field(
|
||||
description="When the token expires (UTC timestamp in seconds)"
|
||||
)
|
||||
tokenRefresh: Optional[str] = None
|
||||
createdAt: Optional[float] = Field(None, description="When the token was created (UTC timestamp in seconds)")
|
||||
status: TokenStatus = Field(default=TokenStatus.ACTIVE, description="Token status: active/revoked")
|
||||
revokedAt: Optional[float] = Field(None, description="When the token was revoked (UTC timestamp in seconds)")
|
||||
revokedBy: Optional[str] = Field(None, description="User ID who revoked the token (admin/self)")
|
||||
createdAt: Optional[float] = Field(
|
||||
None, description="When the token was created (UTC timestamp in seconds)"
|
||||
)
|
||||
status: TokenStatus = Field(
|
||||
default=TokenStatus.ACTIVE, description="Token status: active/revoked"
|
||||
)
|
||||
revokedAt: Optional[float] = Field(
|
||||
None, description="When the token was revoked (UTC timestamp in seconds)"
|
||||
)
|
||||
revokedBy: Optional[str] = Field(
|
||||
None, description="User ID who revoked the token (admin/self)"
|
||||
)
|
||||
reason: Optional[str] = Field(None, description="Optional revocation reason")
|
||||
sessionId: Optional[str] = Field(None, description="Logical session grouping for logout revocation")
|
||||
mandateId: Optional[str] = Field(None, description="Mandate ID for tenant scoping of the token")
|
||||
sessionId: Optional[str] = Field(
|
||||
None, description="Logical session grouping for logout revocation"
|
||||
)
|
||||
mandateId: Optional[str] = Field(
|
||||
None, description="Mandate ID for tenant scoping of the token"
|
||||
)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"Token",
|
||||
{"en": "Token", "fr": "Jeton"},
|
||||
|
|
@ -54,15 +73,64 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class AuthEvent(BaseModel, ModelMixin):
    """Audit record for a single authentication-related event.

    NOTE(review): the extracted source contained every field twice (an
    unformatted and a reformatted copy — a diff-extraction artifact). The
    duplicates were removed here, keeping the formatted definitions, which
    are the ones that took effect anyway (later class-body assignments win).
    """

    # Unique identifier, generated automatically per event.
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the auth event",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False,
    )
    # Owner of the event.
    userId: str = Field(
        description="ID of the user this event belongs to",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True,
    )
    eventType: str = Field(
        description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True,
    )
    # Defaults to "now" at creation time via get_utc_timestamp.
    timestamp: float = Field(
        default_factory=get_utc_timestamp,
        description="Unix timestamp when the event occurred",
        frontend_type="datetime",
        frontend_readonly=True,
        frontend_required=True,
    )
    ipAddress: Optional[str] = Field(
        default=None,
        description="IP address from which the event originated",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False,
    )
    userAgent: Optional[str] = Field(
        default=None,
        description="User agent string from the request",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False,
    )
    success: bool = Field(
        default=True,
        description="Whether the authentication event was successful",
        frontend_type="boolean",
        frontend_readonly=True,
        frontend_required=True,
    )
    details: Optional[str] = Field(
        default=None,
        description="Additional details about the event",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False,
    )
|
||||
|
||||
|
||||
register_model_labels(
|
||||
"AuthEvent",
|
||||
{"en": "Authentication Event", "fr": "Événement d'authentification"},
|
||||
|
|
@ -77,5 +145,3 @@ register_model_labels(
|
|||
"details": {"en": "Details", "fr": "Détails"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
1
modules/features/chatBot/chatbotTools/__init__.py
Normal file
1
modules/features/chatBot/chatbotTools/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Contains all tools available for the chatbot to use."""
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tools that are shared between multiple customers go here."""
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
"""Althaus Database Query Tool for LangGraph.
|
||||
|
||||
This tool provides database query capabilities for the Althaus database
|
||||
via an external REST API. Only SELECT queries are allowed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Annotated
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _mock_api_call(*, sql_query: str) -> dict:
|
||||
"""Mock the external REST API call to Althaus database.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL SELECT query to execute
|
||||
|
||||
Returns:
|
||||
A dictionary containing the query results with columns and rows
|
||||
"""
|
||||
# Simulate network delay
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Mock response data based on common query patterns
|
||||
if "users" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["id", "username", "email", "created_at"],
|
||||
"rows": [
|
||||
[1, "john_doe", "john@example.com", "2024-01-15"],
|
||||
[2, "jane_smith", "jane@example.com", "2024-02-20"],
|
||||
[3, "bob_wilson", "bob@example.com", "2024-03-10"],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
elif "products" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["product_id", "name", "price", "stock"],
|
||||
"rows": [
|
||||
[101, "Widget A", 29.99, 150],
|
||||
[102, "Widget B", 39.99, 75],
|
||||
[103, "Widget C", 19.99, 200],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
elif "orders" in sql_query.lower():
|
||||
return {
|
||||
"columns": ["order_id", "customer_id", "total", "status"],
|
||||
"rows": [
|
||||
[5001, 1, 129.99, "completed"],
|
||||
[5002, 2, 89.50, "pending"],
|
||||
[5003, 1, 199.99, "shipped"],
|
||||
],
|
||||
"row_count": 3,
|
||||
}
|
||||
else:
|
||||
# Generic response for other queries
|
||||
return {
|
||||
"columns": ["id", "value", "description"],
|
||||
"rows": [
|
||||
[1, "Sample 1", "First sample entry"],
|
||||
[2, "Sample 2", "Second sample entry"],
|
||||
],
|
||||
"row_count": 2,
|
||||
}
|
||||
|
||||
|
||||
def _validate_select_query(*, sql_query: str) -> tuple[bool, str]:
|
||||
"""Validate that the query is a SELECT statement only.
|
||||
|
||||
Args:
|
||||
sql_query: The SQL query to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message)
|
||||
"""
|
||||
# Remove leading/trailing whitespace and convert to lowercase for checking
|
||||
normalized_query = sql_query.strip().lower()
|
||||
|
||||
# Check if query starts with SELECT
|
||||
if not normalized_query.startswith("select"):
|
||||
return False, "Query must be a SELECT statement"
|
||||
|
||||
# Check for dangerous keywords that should not be in a SELECT query
|
||||
dangerous_keywords = [
|
||||
"insert",
|
||||
"update",
|
||||
"delete",
|
||||
"drop",
|
||||
"create",
|
||||
"alter",
|
||||
"truncate",
|
||||
"grant",
|
||||
"revoke",
|
||||
"exec",
|
||||
"execute",
|
||||
]
|
||||
|
||||
for keyword in dangerous_keywords:
|
||||
# Use word boundary to match whole words only
|
||||
if re.search(rf"\b{keyword}\b", normalized_query):
|
||||
return False, f"Query contains forbidden keyword: {keyword.upper()}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _format_results(*, columns: list[str], rows: list[list], row_count: int) -> str:
|
||||
"""Format query results into a readable string.
|
||||
|
||||
Args:
|
||||
columns: List of column names
|
||||
rows: List of row data
|
||||
row_count: Total number of rows
|
||||
|
||||
Returns:
|
||||
Formatted string representation of the results
|
||||
"""
|
||||
if row_count == 0:
|
||||
return "Query executed successfully but returned no results."
|
||||
|
||||
# Calculate column widths
|
||||
col_widths = [len(str(col)) for col in columns]
|
||||
for row in rows:
|
||||
for i, cell in enumerate(row):
|
||||
col_widths[i] = max(col_widths[i], len(str(cell)))
|
||||
|
||||
# Build header
|
||||
header_parts = []
|
||||
for col, width in zip(columns, col_widths):
|
||||
header_parts.append(str(col).ljust(width))
|
||||
header = " | ".join(header_parts)
|
||||
separator = "-" * len(header)
|
||||
|
||||
# Build rows
|
||||
row_lines = []
|
||||
for row in rows:
|
||||
row_parts = []
|
||||
for cell, width in zip(row, col_widths):
|
||||
row_parts.append(str(cell).ljust(width))
|
||||
row_lines.append(" | ".join(row_parts))
|
||||
|
||||
# Combine all parts
|
||||
result_parts = [
|
||||
f"Query returned {row_count} row(s):\n",
|
||||
header,
|
||||
separator,
|
||||
"\n".join(row_lines),
|
||||
]
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
|
||||
@tool
async def query_althaus_database(
    sql_query: Annotated[
        str, "The SQL SELECT query to execute against the Althaus database"
    ],
) -> str:
    """Execute a SELECT query against the Althaus database via REST API.

    Use this tool to query data from the Althaus database. Only SELECT statements
    are allowed for security reasons. The query will be forwarded to an external
    REST API and the results will be returned in a formatted table.

    Args:
        sql_query: The SQL SELECT query to execute (e.g., "SELECT * FROM users WHERE id = 1")

    Returns:
        A formatted string containing the query results with columns and rows
    """
    try:
        # Reject anything that is not a plain SELECT before touching the API.
        is_valid, error_msg = _validate_select_query(sql_query=sql_query)
        if not is_valid:
            logger.warning(f"Invalid query attempt: {sql_query[:100]}...")
            return f"Error: {error_msg}"

        logger.info(f"Executing Althaus database query: {sql_query[:100]}...")

        # NOTE: the REST call is mocked for now. The production version would
        # POST the query to the Althaus endpoint instead, e.g.:
        #   response = await httpx.AsyncClient().post(
        #       "https://api.althaus.example.com/query",
        #       json={"query": sql_query},
        #       headers={"Authorization": f"Bearer {api_key}"},
        #   )
        #   result = response.json()
        result = await _mock_api_call(sql_query=sql_query)

        formatted_output = _format_results(
            columns=result["columns"],
            rows=result["rows"],
            row_count=result["row_count"],
        )
        logger.info(
            f"Query completed successfully, returned {result['row_count']} row(s)"
        )
        return formatted_output

    except Exception as e:
        # Return the failure as text so the agent can react instead of crashing.
        logger.error(f"Error in query_althaus_database tool: {str(e)}")
        return f"Error executing query: {str(e)}"
|
||||
|
|
@ -0,0 +1,362 @@
|
|||
"""Power BI Query Tool for LangGraph.
|
||||
|
||||
This tool provides DAX query capabilities for Power BI datasets
|
||||
via the Power BI REST API. Only read-only queries are allowed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import functools
|
||||
from typing import Annotated
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from msal import ConfidentialClientApplication, SerializableTokenCache
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Configuration constants - encapsulated in this file
# Credentials/IDs come from APP_CONFIG; the empty-string fallbacks defer
# failure to _validate_environment() instead of crashing at import time.
POWERBI_DATASET_ID = APP_CONFIG.get("VALUEON_POWERBI_DATASET_ID", "")
POWERBI_CLIENT_ID = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_ID", "")
POWERBI_CLIENT_SECRET = APP_CONFIG.get("VALUEON_POWERBI_CLIENT_SECRET", "")
POWERBI_TENANT_ID = APP_CONFIG.get("VALUEON_POWERBI_TENANT_ID", "")
# Power BI REST API root and the Azure AD login endpoint used by MSAL.
POWERBI_BASE_URL = "https://api.powerbi.com/v1.0/myorg"
POWERBI_AUTHORITY_BASE = "https://login.microsoftonline.com"
# OAuth scope for the client-credentials (service principal) flow.
POWERBI_SCOPE = ["https://analysis.windows.net/powerbi/api/.default"]

# Limit results to prevent excessive context usage
MAX_ROWS_LIMIT = 100
|
||||
|
||||
|
||||
def _validate_environment() -> tuple[bool, str]:
    """Validate that all required Power BI configuration values are set.

    Checks the module-level constants that were loaded from APP_CONFIG at
    import time.

    Returns:
        A tuple of (is_valid, error_message); the error message lists every
        missing configuration key.
    """
    # Map each constant to the APP_CONFIG key it was loaded from, so the
    # error message names the key an operator must actually set.
    # (Bug fix: the message previously reported un-prefixed names such as
    # "POWERBI_DATASET_ID", which are not the configured key names.)
    required = [
        (POWERBI_DATASET_ID, "VALUEON_POWERBI_DATASET_ID"),
        (POWERBI_CLIENT_ID, "VALUEON_POWERBI_CLIENT_ID"),
        (POWERBI_CLIENT_SECRET, "VALUEON_POWERBI_CLIENT_SECRET"),
        (POWERBI_TENANT_ID, "VALUEON_POWERBI_TENANT_ID"),
    ]
    missing = [key for value, key in required if not value]

    if missing:
        return False, f"Missing required environment variables: {', '.join(missing)}"

    return True, ""
|
||||
|
||||
|
||||
def _validate_dax_query(*, dax_query: str) -> tuple[bool, str]:
|
||||
"""Validate that the query is a valid DAX query.
|
||||
|
||||
Args:
|
||||
dax_query: The DAX query to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message)
|
||||
"""
|
||||
# Remove leading/trailing whitespace
|
||||
normalized_query = dax_query.strip()
|
||||
|
||||
if not normalized_query:
|
||||
return False, "Query cannot be empty"
|
||||
|
||||
# DAX queries typically start with EVALUATE, DEFINE, or are table expressions
|
||||
# We'll be lenient and just check it's not trying to do something dangerous
|
||||
# DAX is read-only by nature, but we validate structure
|
||||
|
||||
# Check for minimum length
|
||||
if len(normalized_query) < 5:
|
||||
return False, "Query is too short to be valid"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _get_access_token_sync(
    *,
    tenant_id: str,
    client_id: str,
    client_secret: str,
    authority_base: str = POWERBI_AUTHORITY_BASE,
    cache: SerializableTokenCache | None = None,
) -> str:
    """Acquire a Power BI access token via MSAL (blocking).

    Args:
        tenant_id: Azure AD tenant ID
        client_id: Application client ID
        client_secret: Application client secret
        authority_base: Azure AD authority base URL
        cache: Optional token cache for reuse

    Returns:
        Access token string

    Raises:
        RuntimeError: If token acquisition fails
    """
    msal_app = ConfidentialClientApplication(
        client_id=client_id,
        authority=f"{authority_base}/{tenant_id}",
        client_credential=client_secret,
        token_cache=cache,
    )

    # Prefer a cached token; otherwise run the client-credentials flow.
    token_response = msal_app.acquire_token_silent(POWERBI_SCOPE, account=None)
    if not token_response:
        token_response = msal_app.acquire_token_for_client(scopes=POWERBI_SCOPE)

    if "access_token" not in token_response:
        raise RuntimeError(
            f"MSAL token error: {token_response.get('error')} - {token_response.get('error_description')}"
        )

    return token_response["access_token"]
|
||||
|
||||
|
||||
async def _get_access_token_async(
    *,
    tenant_id: str,
    client_id: str,
    client_secret: str,
    **kwargs,
) -> str:
    """Acquire a Power BI access token without blocking the event loop.

    Args:
        tenant_id: Azure AD tenant ID
        client_id: Application client ID
        client_secret: Application client secret
        **kwargs: Additional arguments for _get_access_token_sync

    Returns:
        Access token string
    """
    # MSAL performs blocking HTTP, so run it on a worker thread to keep the
    # event loop responsive.
    return await anyio.to_thread.run_sync(
        functools.partial(
            _get_access_token_sync,
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
            **kwargs,
        )
    )
|
||||
|
||||
|
||||
async def _execute_dax_query(
    *, dax_query: str, dataset_id: str, access_token: str
) -> dict:
    """Run a DAX query through the Power BI executeQueries endpoint.

    Args:
        dax_query: The DAX query to execute
        dataset_id: Power BI dataset ID
        access_token: Access token for authentication

    Returns:
        Dictionary with "columns" (list of names) and "rows" (list of dicts)

    Raises:
        RuntimeError: If query execution fails or the response shape is
            unexpected
    """
    request_body = {
        "queries": [{"query": dax_query}],
        "serializerSettings": {"includeNulls": True},
    }
    request_headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }

    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{POWERBI_BASE_URL}/datasets/{dataset_id}/executeQueries",
            headers=request_headers,
            json=request_body,
        )

    if resp.status_code != 200:
        raise RuntimeError(
            f"Power BI executeQueries failed: {resp.status_code} - {resp.text}"
        )

    payload = resp.json()

    try:
        rows = payload["results"][0]["tables"][0]["rows"]
    except (KeyError, IndexError) as e:
        raise RuntimeError("Unexpected executeQueries response structure") from e

    # Rows arrive as dicts; derive column names from the first one (if any).
    columns = list(rows[0].keys()) if rows else []

    return {"columns": columns, "rows": rows}
|
||||
|
||||
|
||||
def _strip_table_qualifier(*, column_name: str) -> str:
|
||||
"""Strip table qualifier from column name.
|
||||
|
||||
Power BI often returns columns as 'Table[Column]'. This strips to 'Column'.
|
||||
|
||||
Args:
|
||||
column_name: The column name to process
|
||||
|
||||
Returns:
|
||||
Processed column name
|
||||
"""
|
||||
if "[" in column_name and column_name.endswith("]"):
|
||||
return column_name.split("[", 1)[1][:-1]
|
||||
return column_name
|
||||
|
||||
|
||||
def _format_results(*, columns: list[str], rows: list[dict], max_rows: int) -> str:
    """Render Power BI query results as a fixed-width text table.

    Args:
        columns: List of column names (possibly table-qualified)
        rows: List of row data (dicts keyed by the raw column names)
        max_rows: Maximum number of rows to display

    Returns:
        Formatted string representation of the results
    """
    total_rows = len(rows)

    if total_rows == 0:
        return "Query executed successfully but returned no results."

    # Headers are shown without the 'Table[...]' qualifier, but lookups into
    # the row dicts must still use the raw (qualified) names.
    display_names = [_strip_table_qualifier(column_name=col) for col in columns]

    visible = rows[:max_rows]

    # Each column is as wide as its widest value (or its header).
    widths = [len(str(name)) for name in display_names]
    for row in visible:
        for idx, raw_name in enumerate(columns):
            widths[idx] = max(widths[idx], len(str(row.get(raw_name, ""))))

    header = " | ".join(str(name).ljust(w) for name, w in zip(display_names, widths))
    separator = "-" * len(header)
    body_lines = [
        " | ".join(
            str(row.get(raw_name, "")).ljust(w)
            for raw_name, w in zip(columns, widths)
        )
        for row in visible
    ]

    parts = [f"Query returned {total_rows} row(s):"]
    if total_rows > max_rows:
        parts.append(f"(Results limited to {max_rows} rows for context efficiency)\n")
    else:
        # Keep the blank line between the summary and the table.
        parts.append("")
    parts.extend([header, separator, "\n".join(body_lines)])

    return "\n".join(parts)
|
||||
|
||||
|
||||
@tool
async def query_powerbi_data(
    dax_query: Annotated[str, "The DAX query to execute against the Power BI dataset"],
) -> str:
    """Execute a DAX query against the Power BI dataset to access warehouse inventory data.

    This tool provides access to a Power BI table called 'data_full' which contains
    articles available in the warehouse of the user. Use DAX (Data Analysis Expressions)
    queries to retrieve and analyze this inventory data.

    Available table:
    - 'data_full': Contains warehouse inventory articles and their details

    Common query patterns:
    - View all data: EVALUATE 'data_full'
    - With filter: EVALUATE FILTER('data_full', [Column] = "Value")
    - Top N rows: EVALUATE TOPN(10, 'data_full', [Column], DESC)
    - Calculated: EVALUATE SUMMARIZE('data_full', [Column1], "Total", SUM([Column2]))

    Results are limited to 100 rows maximum for efficiency.

    Args:
        dax_query: The DAX query to execute (e.g., "EVALUATE 'data_full'")

    Returns:
        A formatted string containing the query results with columns and rows
    """
    try:
        # Fail fast when the service-principal configuration is incomplete.
        is_valid_env, error_msg = _validate_environment()
        if not is_valid_env:
            logger.error(f"Environment validation failed: {error_msg}")
            return f"Configuration Error: {error_msg}"

        # Cheap structural check before any network traffic.
        is_valid_query, error_msg = _validate_dax_query(dax_query=dax_query)
        if not is_valid_query:
            logger.warning(f"Invalid query attempt: {dax_query[:100]}...")
            return f"Query Validation Error: {error_msg}"

        logger.info(f"Executing Power BI query: {dax_query[:100]}...")

        # Token acquisition runs MSAL (blocking HTTP) on a worker thread.
        access_token = await _get_access_token_async(
            tenant_id=POWERBI_TENANT_ID,
            client_id=POWERBI_CLIENT_ID,
            client_secret=POWERBI_CLIENT_SECRET,
        )

        query_result = await _execute_dax_query(
            dax_query=dax_query,
            dataset_id=POWERBI_DATASET_ID,
            access_token=access_token,
        )

        formatted_output = _format_results(
            columns=query_result["columns"],
            rows=query_result["rows"],
            max_rows=MAX_ROWS_LIMIT,
        )
        logger.info(
            f"Query completed successfully, returned {len(query_result['rows'])} row(s)"
        )
        return formatted_output

    except RuntimeError as e:
        # Raised by token acquisition / executeQueries with a descriptive message.
        logger.error(f"Runtime error in query_powerbi_data tool: {str(e)}")
        return f"Error executing query: {str(e)}"
    except Exception as e:
        logger.error(f"Unexpected error in query_powerbi_data tool: {str(e)}")
        return f"Unexpected error: {str(e)}"
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""Shared tools available across all chatbot implementations."""
|
||||
|
||||
from modules.features.chatBot.chatbotTools.sharedTools.toolTavilySearch import (
|
||||
tavily_search,
|
||||
)
|
||||
|
||||
__all__ = ["tavily_search"]
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
"""Tool for sending streaming status updates to users."""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
@tool
def send_streaming_message(message: str) -> str:
    """Send a streaming message to the user to provide updates during processing.

    Use this tool to send short status updates to the user while you are working
    on their request. This helps keep the user informed about what you are doing.

    Args:
        message: A short message describing what you are currently doing.
            Examples: "Searching database for relevant information..."
            "Analyzing search results..."
            "Processing your request..."

    Returns:
        A confirmation that the message was sent.
    """
    # This tool doesn't actually do anything - it's just for the AI to signal
    # what it's doing to the frontend via the tool call mechanism
    # (presumably the frontend observes the tool-*call* event itself; the
    # return value only acknowledges receipt back to the model — verify
    # against the streaming handler).
    return f"Status update sent: {message}"
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
"""Tavily Search Tool for LangGraph.
|
||||
|
||||
This tool provides web search capabilities using the Tavily API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from langchain_core.tools import tool
|
||||
from modules.connectors.connectorAiTavily import ConnectorWeb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool
async def tavily_search(
    query: Annotated[str, "The search query to look up on the web"],
) -> str:
    """Search the web using Tavily API.

    Use this tool to search for current information, news, or any web content.
    The tool returns relevant search results including titles and URLs.

    Args:
        query: The search query string

    Returns:
        A formatted string containing search results with titles and URLs
    """
    try:
        connector = await ConnectorWeb.create()

        # NOTE(review): this calls the connector's private _search; a public
        # wrapper would be preferable if the connector ever exposes one.
        results = await connector._search(
            query=query,
            max_results=5,
            search_depth="basic",
            include_answer=True,
            include_raw_content=False,
        )

        if not results:
            return f"No results found for query: {query}"

        # Numbered list: title on one line, URL indented on the next.
        lines = [f"Search results for '{query}':\n"]
        for rank, hit in enumerate(results, 1):
            lines.append(f"{rank}. {hit.title}")
            lines.append(f" URL: {hit.url}\n")
        return "\n".join(lines)

    except Exception as e:
        # Return the failure as text so the agent can react instead of crashing.
        logger.error(f"Error in tavily_search tool: {str(e)}")
        return f"Error performing search: {str(e)}"
|
||||
197
modules/features/chatBot/database.py
Normal file
197
modules/features/chatBot/database.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
from typing import AsyncIterator
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
# Tools Table
class Tool(Base):
    """Available chatbot tools.

    Stores information about all available tools that can be assigned to users.
    Each tool has a unique tool_id that corresponds to the registry tool_id.
    """

    __tablename__ = "tools"
    # Surrogate primary key (random UUID per row).
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # Stable identifier matching the in-code tool registry entry.
    tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    # Tool name as reported by the registry (overwritten on each sync).
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Display label; user-editable, so registry syncs preserve it.
    label: Mapped[str] = mapped_column(String(255), nullable=False)
    # Registry category the tool belongs to.
    category: Mapped[str] = mapped_column(String(50), nullable=False)
    # Human-readable description; user-editable, preserved on sync.
    description: Mapped[str] = mapped_column(String(1000), nullable=False)
    # Set False when the tool disappears from the registry.
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Timezone-aware creation timestamp; lambda so it is evaluated per insert.
    date_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # NOTE(review): no onupdate= is declared — callers must bump this
    # explicitly (sync_tools_from_registry does so).
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
|
||||
|
||||
|
||||
# User-Tool Mapping Table
class UserToolMapping(Base):
    """Mapping of users to their available tools.

    Many-to-many relationship between users and tools.
    - One user can have multiple tools
    - One tool can be assigned to multiple users

    The combination of user_id and tool_id is unique.
    """

    __tablename__ = "user_tools"
    # Enforce at most one row per (user, tool) pair.
    __table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),)

    # Surrogate primary key.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # External user identifier (string; no FK — users are not stored here).
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # UUID of the tool — presumably Tool.id; no ForeignKey is declared, so
    # referential integrity rests on application code (TODO confirm).
    tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False)
    # Soft-disable flag: revoke a grant without deleting the row.
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # When the tool was granted (timezone-aware, evaluated per insert).
    date_granted: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # NOTE(review): no onupdate= — must be bumped explicitly by callers.
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
|
||||
|
||||
|
||||
# User Thread Mapping Table
class UserThreadMapping(Base):
    """Mapping of users to their chat threads.

    Used to keep track of which user owns which chat thread.
    Also stores meta data like thread name.

    1:N relationship between user and thread. A thread belongs to exactly one user.
    A user can have multiple threads.
    Thread_id is unique in the table.
    """

    __tablename__ = "user_threads"
    # Surrogate primary key.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # Owning user's external identifier.
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # Chat thread identifier; unique, so a thread has exactly one owner row.
    thread_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    # Thread metadata: its human-readable name.
    thread_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Timezone-aware creation timestamp (lambda → evaluated per insert).
    date_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # NOTE(review): no onupdate= — update explicitly (e.g. on rename).
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
|
||||
|
||||
|
||||
# FastAPI dependency that resolves the sessionmaker stored on app.state
# (set in app.py during lifespan startup via @asynccontextmanager).
# TODO: If we use SQLAlchemy in other places, we can move this to a shared module
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
    """Yield a per-request SQLAlchemy AsyncSession.

    Pulls the application-wide async_sessionmaker off ``request.app.state``
    and opens a session whose lifetime is bound to the request; the context
    manager closes it when the request finishes.
    """
    session_factory: async_sessionmaker[AsyncSession] = (
        request.app.state.checkpoint_sessionmaker
    )
    async with session_factory() as session:
        yield session
|
||||
|
||||
|
||||
# Optional helper to init tables at startup (demo only)
async def init_models(engine) -> None:
    """Create all tables declared on Base.metadata if they do not exist.

    Intended for startup/demo use; schema evolution should go through
    migrations instead.
    """
    async with engine.begin() as conn:
        # create_all is synchronous, so run it through the async
        # connection's sync bridge.
        await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def sync_tools_from_registry(session: AsyncSession) -> None:
    """Sync tools from tool registry to database.

    This function:
    - Adds new tools from the registry to the database
    - Updates existing tools with current registry information
    - Marks tools not present in the registry as inactive

    Should be called on application startup after database initialization.
    Note: changes are made on the session; the caller is responsible for
    committing (the session context manager / caller owns the transaction).

    Args:
        session: Active database session
    """
    import logging

    from sqlalchemy import select

    from modules.features.chatBot.utils.toolRegistry import get_registry

    logger = logging.getLogger(__name__)
    logger.info("Syncing tools from registry to database...")

    # Get all tools from the registry
    registry_tools = get_registry().get_all_tools()
    logger.info(f"Found {len(registry_tools)} tools in registry")

    # Get all existing tools from the database
    result = await session.execute(select(Tool))
    db_tools = result.scalars().all()
    logger.info(f"Found {len(db_tools)} tools in database")

    added_count, updated_count = _upsert_registry_tools(
        registry_tools=registry_tools,
        db_tools=db_tools,
        session=session,
        logger=logger,
    )
    deactivated_count = _deactivate_missing_tools(
        registry_tool_ids={tool.tool_id for tool in registry_tools},
        db_tools=db_tools,
        logger=logger,
    )

    logger.info(
        f"Tool sync complete: {added_count} added, {updated_count} updated, {deactivated_count} deactivated"
    )


def _upsert_registry_tools(*, registry_tools, db_tools, session, logger) -> tuple[int, int]:
    """Add registry tools missing from the DB and refresh those already present.

    Existing rows keep their user-editable ``label`` and ``description``;
    only registry-owned fields (name, category, active flag) are refreshed.

    Returns:
        Tuple of ``(added_count, updated_count)``.
    """
    db_tools_by_tool_id = {tool.tool_id: tool for tool in db_tools}
    added_count = 0
    updated_count = 0

    for registry_tool in registry_tools:
        db_tool = db_tools_by_tool_id.get(registry_tool.tool_id)
        if db_tool is not None:
            # Tool exists - update registry-owned fields only.
            db_tool.name = registry_tool.name
            db_tool.category = registry_tool.category
            db_tool.is_active = True
            db_tool.date_updated = datetime.now(timezone.utc)
            updated_count += 1
            logger.debug(f"Updated tool: {registry_tool.tool_id}")
        else:
            # Tool doesn't exist - create it.
            new_tool = Tool(
                tool_id=registry_tool.tool_id,
                name=registry_tool.name,
                label=registry_tool.tool_id,  # Use tool_id as label per spec
                category=registry_tool.category,
                description=registry_tool.description or "",
                is_active=True,
            )
            session.add(new_tool)
            added_count += 1
            logger.debug(f"Added new tool: {registry_tool.tool_id}")

    return added_count, updated_count


def _deactivate_missing_tools(*, registry_tool_ids, db_tools, logger) -> int:
    """Mark active DB tools that are absent from the registry as inactive.

    Returns:
        Number of tools deactivated.
    """
    deactivated_count = 0
    for db_tool in db_tools:
        if db_tool.tool_id not in registry_tool_ids and db_tool.is_active:
            db_tool.is_active = False
            db_tool.date_updated = datetime.now(timezone.utc)
            deactivated_count += 1
            logger.debug(f"Deactivated tool not in registry: {db_tool.tool_id}")
    return deactivated_count
|
||||
1
modules/features/chatBot/domain/__init__.py
Normal file
1
modules/features/chatBot/domain/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Domain logic for chatbot functionality."""
|
||||
301
modules/features/chatBot/domain/chatbot.py
Normal file
301
modules/features/chatBot/domain/chatbot.py
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
"""Chatbot domain logic with LangGraph integration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, AsyncIterator, Any
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
trim_messages,
|
||||
)
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatState(BaseModel):
    """Represents the state of a chat session."""

    # Conversation history. The `add_messages` reducer makes LangGraph
    # append/merge node outputs into this list instead of replacing it.
    messages: Annotated[list[BaseMessage], add_messages]
|
||||
|
||||
|
||||
def get_langchain_model(*, model_name: str) -> ChatAnthropic:
    """Map permission model names to LangChain ChatAnthropic models.

    Args:
        model_name: The model name from permissions (e.g., "claude_4_5")

    Returns:
        Configured ChatAnthropic instance

    Note:
        Unknown model names fall back to the default Anthropic model with a
        warning; this function never raises for an unknown name.
    """
    # Model name mapping: permission identifier -> Anthropic model ID
    model_mapping = {
        "claude_4_5": "claude-sonnet-4-5",
        # Add more mappings as needed
    }
    # Fallback must be a valid Anthropic model ID. The previous literal
    # "claude-4-5-sonnet" did not match the mapped ID above and would be
    # rejected by the Anthropic API.
    default_model = "claude-sonnet-4-5"

    anthropic_model = model_mapping.get(model_name)
    if not anthropic_model:
        logger.warning(
            f"Unknown model name '{model_name}', defaulting to {default_model}"
        )
        anthropic_model = default_model

    return ChatAnthropic(
        model=anthropic_model,
        api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
        temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
        max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
    )
|
||||
|
||||
|
||||
@dataclass
class Chatbot:
    """Represents a chatbot with LangGraph integration."""

    # Chat model (expected to support .bind_tools / .invoke — typed Any to
    # avoid a hard dependency here).
    model: Any
    # LangGraph checkpointer used as conversation memory; presumably an
    # AsyncPostgresSaver in this app — TODO confirm against app.py wiring.
    memory: Any
    # Compiled LangGraph app; populated by create() via _build_app().
    app: Any = None
    system_prompt: str = "You are a helpful assistant."
    # Approximate character budget for the trimmed context window (the
    # counter below counts characters, not real tokens).
    context_window_size: int = 100000

    @classmethod
    async def create(
        cls,
        *,
        model: Any,
        memory: Any,
        system_prompt: str,
        tools: list,
        context_window_size: int = 100000,
    ) -> "Chatbot":
        """Factory method to create and configure a Chatbot instance.

        Args:
            model: The chat model to use.
            memory: The chat memory checkpointer to use.
            system_prompt: The system prompt to initialize the chatbot.
            tools: List of LangChain tools the chatbot can use.
            context_window_size: Maximum tokens for context window.

        Returns:
            A configured Chatbot instance.
        """
        instance = cls(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
            context_window_size=context_window_size,
        )
        # Graph compilation happens here (not in __init__) so the dataclass
        # stays trivially constructible.
        instance.app = instance._build_app(memory=memory, tools=tools)
        return instance

    def _build_app(
        self, *, memory: Any, tools: list
    ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
        """Builds the chatbot application workflow using LangGraph.

        Args:
            memory: The chat memory checkpointer to use.
            tools: The list of tools the chatbot can use.

        Returns:
            A compiled state graph representing the chatbot application.
        """
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
            """Selects a window of messages that fit within the context window size.

            Args:
                msgs: The list of messages to select from.

            Returns:
                A list of messages that fit within the context window size.
            """

            def approx_counter(items: list[BaseMessage]) -> int:
                """Approximate token counter for messages.

                NOTE: counts characters of message content, not true tokens,
                so context_window_size is effectively a character budget.

                Args:
                    items: List of messages to count tokens for.

                Returns:
                    Approximate number of tokens in the messages.
                """
                return sum(len(getattr(m, "content", "") or "") for m in items)

            return trim_messages(
                msgs,
                strategy="last",
                token_counter=approx_counter,
                max_tokens=self.context_window_size,
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )

        def agent_node(state: ChatState) -> dict:
            """Agent node for the chatbot workflow.

            Args:
                state: The current chat state.

            Returns:
                The updated chat state after processing.
            """
            # Select the message window to fit in context (trim if needed)
            window = select_window(state.messages)

            # Ensure the system prompt is present at the start
            if not window or not isinstance(window[0], SystemMessage):
                window = [SystemMessage(content=self.system_prompt)] + window

            # Call the LLM with tools (synchronous invoke inside the graph node)
            response = llm_with_tools.invoke(window)

            # Return the new state (add_messages reducer appends `response`)
            return {"messages": [response]}

        def should_continue(state: ChatState) -> str:
            """Determines whether to continue the workflow or end it.

            This conditional edge is called after the agent node to decide
            whether to continue to the tools node (if the last message contains
            tool calls) or to end the workflow (if no tool calls are present).

            Args:
                state: The current chat state.

            Returns:
                The next node to transition to ("tools" or END).
            """
            # Get the last message
            last_message = state.messages[-1]

            # Check if the last message contains tool calls
            # If so, continue to the tools node; otherwise, end the workflow
            return "tools" if getattr(last_message, "tool_calls", None) else END

        # Compose the workflow: agent -> (tools -> agent)* -> END
        workflow = StateGraph(ChatState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", ToolNode(tools=tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue)
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=memory)

    async def chat(
        self, *, message: str, chat_id: str = "default"
    ) -> list[BaseMessage]:
        """Processes a chat message and returns the chat history.

        Args:
            message: The user message to process.
            chat_id: The chat thread ID.

        Returns:
            The list of messages in the chat history.
        """
        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}

        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        )

        # Extract and return the messages from the result
        return result["messages"]

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
    ) -> AsyncIterator[dict]:
        """Stream UI-focused events using astream_events v2.

        Args:
            message: The user message to process.
            chat_id: Logical thread identifier; forwarded in the runnable config so
                memory and tools are scoped per thread.

        Yields:
            dict: One of:
                - ``{"type": "status", "label": str}`` for short progress updates.
                - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
                  where ``chat_history`` only includes ``user``/``assistant`` roles.
                - ``{"type": "error", "message": str}`` if an exception occurs.
        """
        # Thread-aware config for LangGraph/LangChain
        config = {"configurable": {"thread_id": chat_id}}

        def _is_root(ev: dict) -> bool:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2",
            ):
                etype = event.get("event")
                ename = event.get("name") or ""
                edata = event.get("data") or {}

                # Stream human-readable progress via the special send_streaming_message tool
                if etype == "on_tool_start" and ename == "send_streaming_message":
                    tool_in = edata.get("input") or {}
                    msg = tool_in.get("message")
                    if isinstance(msg, str) and msg.strip():
                        yield {"type": "status", "label": msg.strip()}
                    continue

                # Emit the final payload when the root run finishes
                if etype == "on_chain_end" and _is_root(event):
                    output_obj = edata.get("output")

                    # Extract message list from the graph's final output
                    final_msgs = ChatStreamingHelper.extract_messages_from_output(
                        output_obj=output_obj
                    )

                    # Normalize for the frontend (only user/assistant with text content)
                    chat_history_payload: list[dict] = []
                    for m in final_msgs:
                        if isinstance(m, BaseMessage):
                            d = ChatStreamingHelper.message_to_dict(msg=m)
                        elif isinstance(m, dict):
                            d = ChatStreamingHelper.dict_message_to_dict(obj=m)
                        else:
                            # Unknown message shape: skip rather than guess
                            continue
                        if d.get("role") in ("user", "assistant") and d.get("content"):
                            chat_history_payload.append(d)

                    yield {
                        "type": "final",
                        "response": {
                            "thread": chat_id,
                            "chat_history": chat_history_payload,
                        },
                    }
                    # Stop the generator once the final payload is emitted
                    return

        except Exception as exc:
            # Emit a single error envelope and end the stream
            logger.error(f"Error in stream_events: {str(exc)}", exc_info=True)
            yield {"type": "error", "message": f"Error processing request: {exc}"}
|
||||
239
modules/features/chatBot/domain/streaming_helper.py
Normal file
239
modules/features/chatBot/domain/streaming_helper.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""Streaming helper utilities for chat message processing and normalization."""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
Role = Literal["user", "assistant", "system", "tool"]
|
||||
|
||||
|
||||
class ChatStreamingHelper:
    """Pure helper methods for streaming and message normalization.

    This class provides static utility methods for converting between different
    message formats, extracting content, and normalizing message structures
    for streaming chat applications.
    """

    @staticmethod
    def role_from_message(*, msg: BaseMessage) -> Role:
        """Extract the role from a BaseMessage instance.

        Args:
            msg: The BaseMessage instance to extract the role from.

        Returns:
            The role as a string literal: "user", "assistant", "system", or "tool".
            Defaults to "assistant" if the message type is not recognized.

        Examples:
            >>> from langchain_core.messages import HumanMessage
            >>> msg = HumanMessage(content="Hello")
            >>> ChatStreamingHelper.role_from_message(msg=msg)
            'user'
        """
        if isinstance(msg, HumanMessage):
            return "user"
        if isinstance(msg, AIMessage):
            return "assistant"
        if isinstance(msg, SystemMessage):
            return "system"
        if isinstance(msg, ToolMessage):
            return "tool"
        # Fall back to an explicit role attribute, then "assistant".
        return getattr(msg, "role", "assistant")

    @staticmethod
    def flatten_content(*, content: Any) -> str:
        """Convert complex content structures to plain text.

        This method handles various content formats including strings, lists of
        content parts, and dictionaries with text fields. It's designed to
        normalize content from different message sources into a consistent
        plain text format.

        Args:
            content: The content to flatten. Can be:
                - str: Returned as-is after stripping whitespace
                - list: Each item processed and joined with newlines; items may
                  be plain strings or dict content blocks
                - dict: Text extracted from "text" or "content" fields
                - None: Returns empty string
                - Any other type: Converted to string

        Returns:
            The flattened content as a plain text string with whitespace stripped.

        Examples:
            >>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
            >>> ChatStreamingHelper.flatten_content(content=content)
            'Hello\\nworld'

            >>> content = {"text": "Simple message"}
            >>> ChatStreamingHelper.flatten_content(content=content)
            'Simple message'
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            parts: List[str] = []
            for part in content:
                if isinstance(part, str):
                    # LangChain content lists may mix plain strings with dict
                    # blocks; the previous implementation silently dropped
                    # string parts.
                    parts.append(part)
                elif isinstance(part, dict):
                    # A str "text" field covers both plain and
                    # {"type": "text", ...} shaped blocks.
                    if "text" in part and isinstance(part["text"], str):
                        parts.append(part["text"])
                    elif "content" in part and isinstance(part["content"], str):
                        parts.append(part["content"])
                    else:
                        # Fallback for unknown dictionary structures
                        val = part.get("value")
                        parts.append(val if isinstance(val, str) else str(part))
                else:
                    # Unknown part type: keep its string representation rather
                    # than losing content.
                    parts.append(str(part))
            return "\n".join(p.strip() for p in parts)
        if isinstance(content, dict):
            if "text" in content and isinstance(content["text"], str):
                return content["text"].strip()
            if "content" in content and isinstance(content["content"], str):
                return content["content"].strip()
        return str(content).strip()

    @staticmethod
    def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
        """Convert a BaseMessage instance to a dictionary for streaming output.

        This method normalizes BaseMessage instances into a consistent dictionary
        format suitable for JSON serialization and streaming to clients.

        Args:
            msg: The BaseMessage instance to convert.

        Returns:
            A dictionary containing:
                - "role": The message role (user, assistant, system, tool)
                - "content": The flattened message content as plain text
                - "tool_calls": Tool calls if present (optional)
                - "name": Message name if present (optional)

        Examples:
            >>> from langchain_core.messages import HumanMessage
            >>> msg = HumanMessage(content="Hello there")
            >>> result = ChatStreamingHelper.message_to_dict(msg=msg)
            >>> result["role"]
            'user'
            >>> result["content"]
            'Hello there'
        """
        payload: Dict[str, Any] = {
            "role": ChatStreamingHelper.role_from_message(msg=msg),
            "content": ChatStreamingHelper.flatten_content(
                content=getattr(msg, "content", "")
            ),
        }
        tool_calls = getattr(msg, "tool_calls", None)
        if tool_calls:
            payload["tool_calls"] = tool_calls
        name = getattr(msg, "name", None)
        if name:
            payload["name"] = name
        return payload

    @staticmethod
    def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
        """Convert a dictionary-shaped message to a normalized dictionary.

        This method handles messages that come from serialized state and are
        represented as dictionaries rather than BaseMessage instances. It
        normalizes various dictionary formats into a consistent structure.

        Args:
            obj: The dictionary-shaped message to convert. Expected to contain
                fields like "role", "type", "content", "text", etc.

        Returns:
            A normalized dictionary containing:
                - "role": The message role (user, assistant, system, tool)
                - "content": The flattened message content as plain text
                - "tool_calls": Tool calls if present (optional)
                - "name": Message name if present (optional)

        Examples:
            >>> obj = {"type": "human", "content": "Hello"}
            >>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
            >>> result["role"]
            'user'
            >>> result["content"]
            'Hello'
        """
        role: Optional[str] = obj.get("role")
        if not role:
            # Handle alternative type field mappings
            typ = obj.get("type")
            if typ in ("human", "user"):
                role = "user"
            elif typ in ("ai", "assistant"):
                role = "assistant"
            elif typ in ("system",):
                role = "system"
            elif typ in ("tool", "function"):
                role = "tool"

        content = obj.get("content")
        if content is None and "text" in obj:
            content = obj["text"]

        out: Dict[str, Any] = {
            "role": role or "assistant",
            "content": ChatStreamingHelper.flatten_content(content=content),
        }
        if "tool_calls" in obj:
            out["tool_calls"] = obj["tool_calls"]
        if obj.get("name"):
            out["name"] = obj["name"]
        return out

    @staticmethod
    def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
        """Extract messages from LangGraph output objects.

        This method handles various output formats from LangGraph execution,
        extracting the messages list from different possible structures.

        Args:
            output_obj: The output object from LangGraph execution. Can be:
                - An object with a "messages" attribute
                - A dictionary with a "messages" key
                - Any other object (returns empty list)

        Returns:
            A list of extracted messages, or an empty list if no messages
            are found or if the output object is None.

        Examples:
            >>> output = {"messages": [{"role": "user", "content": "Hello"}]}
            >>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
            >>> len(messages)
            1
        """
        if output_obj is None:
            return []

        # Try to parse dicts first
        if isinstance(output_obj, dict):
            msgs = output_obj.get("messages")
            return msgs if isinstance(msgs, list) else []

        # Then try to get messages attribute
        msgs = getattr(output_obj, "messages", None)
        return msgs if isinstance(msgs, list) else []
|
||||
1022
modules/features/chatBot/service.py
Normal file
1022
modules/features/chatBot/service.py
Normal file
File diff suppressed because it is too large
Load diff
106
modules/features/chatBot/utils/checkpointer.py
Normal file
106
modules/features/chatBot/utils/checkpointer.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"""PostgreSQL checkpointer utilities for LangGraph memory."""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
# Fix for Windows asyncio compatibility with psycopg (backup in case app.py fix didn't apply)
|
||||
if sys.platform == 'win32':
|
||||
try:
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
except RuntimeError:
|
||||
pass # Already set
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.rows import dict_row
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global checkpointer instance
|
||||
_checkpointer_instance: Optional[AsyncPostgresSaver] = None
|
||||
_connection_pool: Optional[AsyncConnectionPool] = None
|
||||
|
||||
|
||||
async def initialize_checkpointer() -> None:
    """Initialize the PostgreSQL checkpointer for LangGraph.

    This should be called during application startup.
    Creates a connection pool and AsyncPostgresSaver instance, and runs the
    saver's setup (creates its tables if needed).

    Raises:
        Exception: Re-raises any failure from pool creation or saver setup
            after logging (and closing any half-open pool).
    """
    global _checkpointer_instance, _connection_pool

    # Idempotent: a second call is a no-op once initialized.
    if _checkpointer_instance is not None:
        logger.info("Checkpointer already initialized")
        return

    try:
        from urllib.parse import quote_plus

        # Get database configuration from environment
        host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost")
        database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat")
        user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev")
        password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET")
        port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432")

        # Build connection URI. Percent-encode credentials so passwords
        # containing reserved URI characters (@ : / # %) don't break parsing.
        # A missing password becomes empty rather than the literal "None".
        connection_string = (
            f"postgresql://{quote_plus(user)}:{quote_plus(password or '')}"
            f"@{host}:{port}/{database}"
        )

        # Create async connection pool
        _connection_pool = AsyncConnectionPool(
            conninfo=connection_string,
            min_size=2,
            max_size=10,
            kwargs={"autocommit": True, "row_factory": dict_row},
        )

        # Initialize the connection pool
        await _connection_pool.open()

        # Create AsyncPostgresSaver with the pool
        _checkpointer_instance = AsyncPostgresSaver(_connection_pool)

        # Setup the checkpointer (creates tables if needed)
        await _checkpointer_instance.setup()

        logger.info("PostgreSQL checkpointer initialized successfully")

    except Exception as e:
        logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}")
        # Don't leak a half-open pool if setup failed partway through.
        if _connection_pool is not None:
            try:
                await _connection_pool.close()
            except Exception:
                logger.warning("Failed to close connection pool during cleanup")
            _connection_pool = None
        _checkpointer_instance = None
        raise
|
||||
|
||||
|
||||
async def close_checkpointer() -> None:
    """Close the checkpointer and connection pool.

    This should be called during application shutdown.
    """
    global _checkpointer_instance, _connection_pool

    pool = _connection_pool
    if pool is not None:
        try:
            await pool.close()
        except Exception as e:
            logger.error(f"Error closing checkpointer connection pool: {str(e)}")
        else:
            logger.info("PostgreSQL checkpointer connection pool closed")

    # Reset module state so a later initialize_checkpointer() starts fresh.
    _checkpointer_instance = None
    _connection_pool = None
|
||||
|
||||
|
||||
def get_checkpointer() -> AsyncPostgresSaver:
    """Get the global PostgreSQL checkpointer instance.

    Returns:
        The initialized AsyncPostgresSaver instance

    Raises:
        RuntimeError: If checkpointer is not initialized
    """
    checkpointer = _checkpointer_instance
    if checkpointer is None:
        raise RuntimeError(
            "PostgreSQL checkpointer not initialized. "
            "Call initialize_checkpointer() during application startup."
        )
    return checkpointer
|
||||
39
modules/features/chatBot/utils/permissions.py
Normal file
39
modules/features/chatBot/utils/permissions.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""Mock permissions module for chatbot access control.
|
||||
|
||||
This module provides mock permission functions that will be replaced
|
||||
with actual database-driven permissions in the future.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
|
||||
|
||||
# TODO: Replace these mock implementations with actual database queries
|
||||
|
||||
|
||||
def get_chatbot_tools(*, user_id: str) -> list[str]:
    """Get list of tool IDs that the chatbot can use for a given user.

    Mock implementation: every user currently gets every registered tool.
    """
    return get_registry().list_tool_ids()
|
||||
|
||||
|
||||
def get_chatbot_model(*, user_id: str) -> str:
    """Gets the chatbot model(s) a user is allowed to use.

    Mock implementation: all users are mapped to the same model tier.
    """
    allowed_model = "claude_4_5"
    return allowed_model
|
||||
|
||||
|
||||
def get_system_prompt(*, user_id: str) -> str:
    """Get the system prompt for a user's chatbot session.

    This is a mock implementation that returns a generic prompt with today's date.
    In production, this will query the database for user-specific or role-specific prompts.

    Args:
        user_id: The unique identifier of the user

    Returns:
        The system prompt string with the current date
    """
    return "You're a smart assistant. Today is " + datetime.now().strftime("%Y-%m-%d")
|
||||
305
modules/features/chatBot/utils/toolRegistry.py
Normal file
305
modules/features/chatBot/utils/toolRegistry.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
"""Tool registry for auto-discovering and managing chatbot tools.
|
||||
|
||||
This module provides a central registry that automatically discovers all tools
|
||||
in the chatbotTools directory structure and provides methods to query them.
|
||||
The registry is built in-memory at startup and does not require a database.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ToolMetadata:
    """Metadata about a discovered chatbot tool.

    Attributes:
        tool_id: Unique identifier (e.g., 'shared.tavily_search')
        name: Function name of the tool
        category: Category of the tool ('shared' or 'customer')
        description: Tool description from docstring
        tool_instance: The actual LangChain tool instance
        module_path: Full Python module path
    """

    tool_id: str
    name: str
    category: str
    description: str
    tool_instance: BaseTool
    module_path: str

    def __str__(self) -> str:
        """Return a pretty-printed string representation for logging."""
        body = ",\n".join(
            [
                f"  tool_id='{self.tool_id}'",
                f"  name='{self.name}'",
                f"  category='{self.category}'",
                f"  description='{self.description}'",
                f"  module_path='{self.module_path}'",
            ]
        )
        return f"ToolMetadata(\n{body}\n)"
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Central registry for all chatbot tools.
|
||||
|
||||
This class discovers and catalogs all tools decorated with @tool in the
|
||||
chatbotTools directory structure. Tools are automatically discovered at
|
||||
initialization by scanning the filesystem.
|
||||
|
||||
The registry provides methods to query tools by ID, category, or get all tools.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an empty tool registry."""
|
||||
self._tools: Dict[str, ToolMetadata] = {}
|
||||
self._initialized: bool = False
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Discover and register all tools from the chatbotTools directory.
|
||||
|
||||
This method scans both sharedTools and customerTools directories,
|
||||
imports all tool*.py modules, and extracts functions decorated with @tool.
|
||||
|
||||
This method is idempotent - calling it multiple times has no effect
|
||||
after the first initialization.
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.debug("Tool registry already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.info("Initializing tool registry...")
|
||||
|
||||
# Get base path to chatbotTools directory
|
||||
base_path = Path(__file__).parent.parent / "chatbotTools"
|
||||
|
||||
if not base_path.exists():
|
||||
logger.warning(f"chatbotTools directory not found at {base_path}")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
# Discover tools in each category
|
||||
self._discover_category(
|
||||
category_path=base_path / "sharedTools", category="shared"
|
||||
)
|
||||
self._discover_category(
|
||||
category_path=base_path / "customerTools", category="customer"
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"Tool registry initialized with {len(self._tools)} tools")
|
||||
|
||||
def _discover_category(self, *, category_path: Path, category: str) -> None:
|
||||
"""Discover all tools in a specific category directory.
|
||||
|
||||
Args:
|
||||
category_path: Path to the category directory (sharedTools or customerTools)
|
||||
category: Category name ('shared' or 'customer')
|
||||
"""
|
||||
if not category_path.exists():
|
||||
logger.warning(f"Category directory not found: {category_path}")
|
||||
return
|
||||
|
||||
logger.debug(f"Discovering tools in category: {category}")
|
||||
|
||||
# Find all tool*.py files (excluding __init__.py)
|
||||
tool_files = [
|
||||
f for f in category_path.glob("tool*.py") if f.name != "__init__.py"
|
||||
]
|
||||
|
||||
for tool_file in tool_files:
|
||||
self._import_and_register_tools(
|
||||
tool_file=tool_file, category=category, category_path=category_path
|
||||
)
|
||||
|
||||
logger.debug(f"Discovered {len(tool_files)} tool files in {category}")
|
||||
|
||||
def _import_and_register_tools(
|
||||
self, *, tool_file: Path, category: str, category_path: Path
|
||||
) -> None:
|
||||
"""Import a tool module and register all discovered tools.
|
||||
|
||||
Args:
|
||||
tool_file: Path to the tool Python file
|
||||
category: Category name ('shared' or 'customer')
|
||||
category_path: Path to the category directory
|
||||
"""
|
||||
# Construct module name
|
||||
module_name = (
|
||||
f"modules.features.chatBot.chatbotTools.{category}Tools.{tool_file.stem}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Find all BaseTool instances in the module
|
||||
tools_found = 0
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if isinstance(obj, BaseTool):
|
||||
self._register_tool(
|
||||
tool_instance=obj,
|
||||
name=name,
|
||||
category=category,
|
||||
module_path=module_name,
|
||||
)
|
||||
tools_found += 1
|
||||
|
||||
if tools_found == 0:
|
||||
logger.warning(f"No tools found in {module_name}")
|
||||
else:
|
||||
logger.debug(f"Loaded {tools_found} tool(s) from {module_name}")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"Import error loading tools from {module_name}: {str(e)}. "
|
||||
f"This tool will not be available."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error loading tools from {module_name}: {type(e).__name__}: {str(e)}"
|
||||
)
|
||||
|
||||
def _register_tool(
|
||||
self, *, tool_instance: BaseTool, name: str, category: str, module_path: str
|
||||
) -> None:
|
||||
"""Register a single tool in the registry.
|
||||
|
||||
Args:
|
||||
tool_instance: The LangChain tool instance
|
||||
name: Function name of the tool
|
||||
category: Category name ('shared' or 'customer')
|
||||
module_path: Full Python module path
|
||||
"""
|
||||
tool_id = f"{category}.{name}"
|
||||
|
||||
# Check for duplicate tool IDs
|
||||
if tool_id in self._tools:
|
||||
logger.warning(f"Duplicate tool ID detected: {tool_id}, overwriting")
|
||||
|
||||
metadata = ToolMetadata(
|
||||
tool_id=tool_id,
|
||||
name=name,
|
||||
category=category,
|
||||
description=tool_instance.description or "",
|
||||
tool_instance=tool_instance,
|
||||
module_path=module_path,
|
||||
)
|
||||
|
||||
self._tools[tool_id] = metadata
|
||||
logger.debug(f"Registered tool: {tool_id}")
|
||||
|
||||
def get_all_tools(self) -> List[ToolMetadata]:
|
||||
"""Get all registered tools.
|
||||
|
||||
Returns:
|
||||
List of all tool metadata objects
|
||||
"""
|
||||
return list(self._tools.values())
|
||||
|
||||
def get_tool(self, *, tool_id: str) -> Optional[ToolMetadata]:
|
||||
"""Get a specific tool by its ID.
|
||||
|
||||
Args:
|
||||
tool_id: The tool identifier (e.g., 'shared.tavily_search')
|
||||
|
||||
Returns:
|
||||
Tool metadata if found, None otherwise
|
||||
"""
|
||||
return self._tools.get(tool_id)
|
||||
|
||||
def get_tools_by_category(self, *, category: str) -> List[ToolMetadata]:
|
||||
"""Get all tools in a specific category.
|
||||
|
||||
Args:
|
||||
category: Category name ('shared' or 'customer')
|
||||
|
||||
Returns:
|
||||
List of tool metadata for the specified category
|
||||
"""
|
||||
return [t for t in self._tools.values() if t.category == category]
|
||||
|
||||
def list_tool_ids(self) -> List[str]:
|
||||
"""Get a list of all registered tool IDs.
|
||||
|
||||
Returns:
|
||||
List of tool ID strings
|
||||
"""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def get_tool_instances(self, *, tool_ids: List[str]) -> List[BaseTool]:
|
||||
"""Get actual tool instances for a list of tool IDs.
|
||||
|
||||
This is useful for filtering tools based on user permissions.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs to retrieve
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances for the specified IDs
|
||||
"""
|
||||
instances = []
|
||||
for tool_id in tool_ids:
|
||||
metadata = self.get_tool(tool_id=tool_id)
|
||||
if metadata:
|
||||
instances.append(metadata.tool_instance)
|
||||
else:
|
||||
logger.warning(f"Tool ID not found in registry: {tool_id}")
|
||||
return instances
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the registry has been initialized.
|
||||
|
||||
Returns:
|
||||
True if initialized, False otherwise
|
||||
"""
|
||||
return self._initialized
|
||||
|
||||
|
||||
# Global registry instance
|
||||
_registry: Optional[ToolRegistry] = None
|
||||
|
||||
|
||||
def get_registry() -> ToolRegistry:
|
||||
"""Get the global tool registry instance.
|
||||
|
||||
This function ensures the registry is initialized on first access.
|
||||
Subsequent calls return the same instance.
|
||||
|
||||
Returns:
|
||||
The global ToolRegistry instance
|
||||
"""
|
||||
global _registry
|
||||
|
||||
if _registry is None:
|
||||
_registry = ToolRegistry()
|
||||
|
||||
if not _registry.is_initialized:
|
||||
_registry.initialize()
|
||||
|
||||
return _registry
|
||||
|
||||
|
||||
def reinitialize_registry() -> ToolRegistry:
|
||||
"""Force reinitialize the tool registry.
|
||||
|
||||
This is useful for testing or when tools are added dynamically.
|
||||
|
||||
Returns:
|
||||
The reinitialized ToolRegistry instance
|
||||
"""
|
||||
global _registry
|
||||
_registry = ToolRegistry()
|
||||
_registry.initialize()
|
||||
return _registry
|
||||
|
|
@ -18,11 +18,19 @@ from modules.shared.configuration import APP_CONFIG
|
|||
from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp
|
||||
from modules.interfaces.interfaceDbAppAccess import AppAccess
|
||||
from modules.datamodels.datamodelUam import (
|
||||
User, Mandate, UserInDB, UserConnection,
|
||||
AuthAuthority, UserPrivilege, ConnectionStatus,
|
||||
User,
|
||||
Mandate,
|
||||
UserInDB,
|
||||
UserConnection,
|
||||
AuthAuthority,
|
||||
UserPrivilege,
|
||||
ConnectionStatus,
|
||||
)
|
||||
from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus
|
||||
from modules.datamodels.datamodelNeutralizer import DataNeutraliserConfig, DataNeutralizerAttributes
|
||||
from modules.datamodels.datamodelNeutralizer import (
|
||||
DataNeutraliserConfig,
|
||||
DataNeutralizerAttributes,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -35,6 +43,7 @@ _rootAppObjects = None
|
|||
# Password-Hashing
|
||||
pwdContext = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
class AppObjects:
|
||||
"""
|
||||
Interface to the Gateway system.
|
||||
|
|
@ -76,14 +85,16 @@ class AppObjects:
|
|||
self.userLanguage = currentUser.language # Default user language
|
||||
|
||||
# Initialize access control with user context
|
||||
self.access = AppAccess(self.currentUser, self.db) # Convert to dict only when needed
|
||||
self.access = AppAccess(
|
||||
self.currentUser, self.db
|
||||
) # Convert to dict only when needed
|
||||
|
||||
# Update database context
|
||||
self.db.updateContext(self.userId)
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup method to close database connection."""
|
||||
if hasattr(self, 'db') and self.db is not None:
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
try:
|
||||
self.db.close()
|
||||
except Exception as e:
|
||||
|
|
@ -106,7 +117,7 @@ class AppObjects:
|
|||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
dbPort=dbPort,
|
||||
userId=self.userId
|
||||
userId=self.userId,
|
||||
)
|
||||
|
||||
# Initialize database system
|
||||
|
|
@ -129,16 +140,12 @@ class AppObjects:
|
|||
mandates = self.db.getRecordset(Mandate)
|
||||
if existingMandateId is None or not mandates:
|
||||
logger.info("Creating Root mandate")
|
||||
rootMandate = Mandate(
|
||||
name="Root",
|
||||
language="en",
|
||||
enabled=True
|
||||
)
|
||||
rootMandate = Mandate(name="Root", language="en", enabled=True)
|
||||
createdMandate = self.db.recordCreate(Mandate, rootMandate)
|
||||
logger.info(f"Root mandate created with ID {createdMandate['id']}")
|
||||
|
||||
# Update mandate context
|
||||
self.mandateId = createdMandate['id']
|
||||
self.mandateId = createdMandate["id"]
|
||||
|
||||
def _initAdminUser(self):
|
||||
"""Creates the Admin user if it doesn't exist."""
|
||||
|
|
@ -155,8 +162,10 @@ class AppObjects:
|
|||
language="en",
|
||||
privilege=UserPrivilege.SYSADMIN,
|
||||
authenticationAuthority="local", # Using lowercase value directly
|
||||
hashedPassword=self._getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET")),
|
||||
connections=[]
|
||||
hashedPassword=self._getPasswordHash(
|
||||
APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET")
|
||||
),
|
||||
connections=[],
|
||||
)
|
||||
createdUser = self.db.recordCreate(UserInDB, adminUser)
|
||||
logger.info(f"Admin user created with ID {createdUser['id']}")
|
||||
|
|
@ -168,7 +177,9 @@ class AppObjects:
|
|||
def _initEventUser(self):
|
||||
"""Creates the Event user if it doesn't exist."""
|
||||
# Check if event user already exists
|
||||
existingUsers = self.db.getRecordset(UserInDB, recordFilter={"username": "event"})
|
||||
existingUsers = self.db.getRecordset(
|
||||
UserInDB, recordFilter={"username": "event"}
|
||||
)
|
||||
if not existingUsers:
|
||||
logger.info("Creating Event user")
|
||||
eventUser = UserInDB(
|
||||
|
|
@ -180,13 +191,17 @@ class AppObjects:
|
|||
language="en",
|
||||
privilege=UserPrivilege.SYSADMIN,
|
||||
authenticationAuthority="local", # Using lowercase value directly
|
||||
hashedPassword=self._getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET")),
|
||||
connections=[]
|
||||
hashedPassword=self._getPasswordHash(
|
||||
APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET")
|
||||
),
|
||||
connections=[],
|
||||
)
|
||||
createdUser = self.db.recordCreate(UserInDB, eventUser)
|
||||
logger.info(f"Event user created with ID {createdUser['id']}")
|
||||
|
||||
def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def _uam(
|
||||
self, model_class: type, recordset: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Unified user access management function that filters data based on user privileges
|
||||
and adds access control attributes.
|
||||
|
|
@ -205,7 +220,7 @@ class AppObjects:
|
|||
cleanedRecords = []
|
||||
for record in filteredRecords:
|
||||
# Create a new dict with only non-database fields
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith('_')}
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
cleanedRecords.append(cleanedRecord)
|
||||
|
||||
return cleanedRecords
|
||||
|
|
@ -317,12 +332,20 @@ class AppObjects:
|
|||
|
||||
return user
|
||||
|
||||
def createUser(self, username: str, password: str = None, email: str = None,
|
||||
fullName: str = None, language: str = "en", enabled: bool = True,
|
||||
def createUser(
|
||||
self,
|
||||
username: str,
|
||||
password: str = None,
|
||||
email: str = None,
|
||||
fullName: str = None,
|
||||
language: str = "en",
|
||||
enabled: bool = True,
|
||||
privilege: UserPrivilege = UserPrivilege.USER,
|
||||
authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL,
|
||||
externalId: str = None, externalUsername: str = None,
|
||||
externalEmail: str = None) -> User:
|
||||
externalId: str = None,
|
||||
externalUsername: str = None,
|
||||
externalEmail: str = None,
|
||||
) -> User:
|
||||
"""Create a new user with optional external connection"""
|
||||
try:
|
||||
# Ensure username is a string
|
||||
|
|
@ -348,7 +371,7 @@ class AppObjects:
|
|||
privilege=privilege,
|
||||
authenticationAuthority=authenticationAuthority,
|
||||
hashedPassword=self._getPasswordHash(password) if password else None,
|
||||
connections=[]
|
||||
connections=[],
|
||||
)
|
||||
|
||||
# Create user record
|
||||
|
|
@ -356,7 +379,6 @@ class AppObjects:
|
|||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create user record")
|
||||
|
||||
|
||||
# Add external connection if provided
|
||||
if externalId and externalUsername:
|
||||
self.addUserConnection(
|
||||
|
|
@ -364,11 +386,13 @@ class AppObjects:
|
|||
authenticationAuthority,
|
||||
externalId,
|
||||
externalUsername,
|
||||
externalEmail
|
||||
externalEmail,
|
||||
)
|
||||
|
||||
# Get created user using the returned ID
|
||||
createdUser = self.db.getRecordset(UserInDB, recordFilter={"id": createdRecord["id"]})
|
||||
createdUser = self.db.getRecordset(
|
||||
UserInDB, recordFilter={"id": createdRecord["id"]}
|
||||
)
|
||||
if not createdUser or len(createdUser) == 0:
|
||||
raise ValueError("Failed to retrieve created user")
|
||||
|
||||
|
|
@ -399,7 +423,6 @@ class AppObjects:
|
|||
# Update user record
|
||||
self.db.recordModify(UserInDB, userId, updatedUser)
|
||||
|
||||
|
||||
# Get updated user
|
||||
updatedUser = self.getUser(userId)
|
||||
if not updatedUser:
|
||||
|
|
@ -422,8 +445,6 @@ class AppObjects:
|
|||
def _deleteUserReferencedData(self, userId: str) -> None:
|
||||
"""Deletes all data associated with a user."""
|
||||
try:
|
||||
|
||||
|
||||
# Delete user auth events
|
||||
events = self.db.getRecordset(AuthEvent, recordFilter={"userId": userId})
|
||||
for event in events:
|
||||
|
|
@ -434,9 +455,10 @@ class AppObjects:
|
|||
for token in tokens:
|
||||
self.db.recordDelete(Token, token["id"])
|
||||
|
||||
|
||||
# Delete user connections
|
||||
connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
|
||||
connections = self.db.getRecordset(
|
||||
UserConnection, recordFilter={"userId": userId}
|
||||
)
|
||||
for conn in connections:
|
||||
self.db.recordDelete(UserConnection, conn["id"])
|
||||
|
||||
|
|
@ -465,7 +487,6 @@ class AppObjects:
|
|||
if not success:
|
||||
raise ValueError(f"Failed to delete user {userId}")
|
||||
|
||||
|
||||
logger.info(f"User {userId} successfully deleted")
|
||||
return True
|
||||
|
||||
|
|
@ -493,31 +514,22 @@ class AppObjects:
|
|||
authenticationAuthority = checkData.get("authenticationAuthority", "local")
|
||||
|
||||
if not username:
|
||||
return {
|
||||
"available": False,
|
||||
"message": "Username is required"
|
||||
}
|
||||
return {"available": False, "message": "Username is required"}
|
||||
|
||||
# Get user by username
|
||||
user = self.getUserByUsername(username)
|
||||
|
||||
# Check if user exists (User model instance)
|
||||
if user is not None:
|
||||
return {
|
||||
"available": False,
|
||||
"message": "Username is already taken"
|
||||
}
|
||||
return {"available": False, "message": "Username is already taken"}
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
"message": "Username is available"
|
||||
}
|
||||
return {"available": True, "message": "Username is available"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking username availability: {str(e)}")
|
||||
return {
|
||||
"available": False,
|
||||
"message": f"Error checking username availability: {str(e)}"
|
||||
"message": f"Error checking username availability: {str(e)}",
|
||||
}
|
||||
|
||||
# Connection methods
|
||||
|
|
@ -526,7 +538,9 @@ class AppObjects:
|
|||
"""Returns all connections for a user."""
|
||||
try:
|
||||
# Get connections for this user
|
||||
connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
|
||||
connections = self.db.getRecordset(
|
||||
UserConnection, recordFilter={"userId": userId}
|
||||
)
|
||||
|
||||
# Convert to UserConnection objects
|
||||
result = []
|
||||
|
|
@ -543,11 +557,13 @@ class AppObjects:
|
|||
status=conn_dict.get("status", "pending"),
|
||||
connectedAt=conn_dict.get("connectedAt"),
|
||||
lastChecked=conn_dict.get("lastChecked"),
|
||||
expiresAt=conn_dict.get("expiresAt")
|
||||
expiresAt=conn_dict.get("expiresAt"),
|
||||
)
|
||||
result.append(connection)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting connection dict to object: {str(e)}")
|
||||
logger.error(
|
||||
f"Error converting connection dict to object: {str(e)}"
|
||||
)
|
||||
continue
|
||||
return result
|
||||
|
||||
|
|
@ -555,9 +571,15 @@ class AppObjects:
|
|||
logger.error(f"Error getting user connections: {str(e)}")
|
||||
return []
|
||||
|
||||
def addUserConnection(self, userId: str, authority: AuthAuthority, externalId: str,
|
||||
externalUsername: str, externalEmail: Optional[str] = None,
|
||||
status: ConnectionStatus = ConnectionStatus.PENDING) -> UserConnection:
|
||||
def addUserConnection(
|
||||
self,
|
||||
userId: str,
|
||||
authority: AuthAuthority,
|
||||
externalId: str,
|
||||
externalUsername: str,
|
||||
externalEmail: Optional[str] = None,
|
||||
status: ConnectionStatus = ConnectionStatus.PENDING,
|
||||
) -> UserConnection:
|
||||
"""
|
||||
Adds a new connection for a user.
|
||||
|
||||
|
|
@ -589,13 +611,12 @@ class AppObjects:
|
|||
status=status,
|
||||
connectedAt=get_utc_timestamp(),
|
||||
lastChecked=get_utc_timestamp(),
|
||||
expiresAt=None # Optional field, set to None by default
|
||||
expiresAt=None, # Optional field, set to None by default
|
||||
)
|
||||
|
||||
# Save to connections table
|
||||
self.db.recordCreate(UserConnection, connection)
|
||||
|
||||
|
||||
return connection
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -606,9 +627,9 @@ class AppObjects:
|
|||
"""Remove a connection to an external service"""
|
||||
try:
|
||||
# Get connection
|
||||
connections = self.db.getRecordset(UserConnection, recordFilter={
|
||||
"id": connectionId
|
||||
})
|
||||
connections = self.db.getRecordset(
|
||||
UserConnection, recordFilter={"id": connectionId}
|
||||
)
|
||||
|
||||
if not connections:
|
||||
raise ValueError(f"Connection {connectionId} not found")
|
||||
|
|
@ -616,7 +637,6 @@ class AppObjects:
|
|||
# Delete connection
|
||||
self.db.recordDelete(UserConnection, connectionId)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing user connection: {str(e)}")
|
||||
raise ValueError(f"Failed to remove user connection: {str(e)}")
|
||||
|
|
@ -647,17 +667,13 @@ class AppObjects:
|
|||
raise PermissionError("No permission to create mandates")
|
||||
|
||||
# Create mandate data using model
|
||||
mandateData = Mandate(
|
||||
name=name,
|
||||
language=language
|
||||
)
|
||||
mandateData = Mandate(name=name, language=language)
|
||||
|
||||
# Create mandate record
|
||||
createdRecord = self.db.recordCreate(Mandate, mandateData)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create mandate record")
|
||||
|
||||
|
||||
return Mandate.from_dict(createdRecord)
|
||||
|
||||
def updateMandate(self, mandateId: str, updateData: Dict[str, Any]) -> Mandate:
|
||||
|
|
@ -707,7 +723,9 @@ class AppObjects:
|
|||
# Check if mandate has users
|
||||
users = self.getUsersByMandate(mandateId)
|
||||
if users:
|
||||
raise ValueError(f"Cannot delete mandate {mandateId} with existing users")
|
||||
raise ValueError(
|
||||
f"Cannot delete mandate {mandateId} with existing users"
|
||||
)
|
||||
|
||||
# Delete mandate
|
||||
success = self.db.recordDelete(Mandate, mandateId)
|
||||
|
|
@ -727,7 +745,9 @@ class AppObjects:
|
|||
try:
|
||||
# Validate that this is NOT a connection token
|
||||
if token.connectionId:
|
||||
raise ValueError("Access tokens cannot have connectionId - use saveConnectionToken instead")
|
||||
raise ValueError(
|
||||
"Access tokens cannot have connectionId - use saveConnectionToken instead"
|
||||
)
|
||||
|
||||
# Validate user context
|
||||
if not self.currentUser or not self.currentUser.id:
|
||||
|
|
@ -745,33 +765,44 @@ class AppObjects:
|
|||
# If replace_existing is True, delete old access tokens for this user and authority first
|
||||
if replace_existing:
|
||||
try:
|
||||
old_tokens = self.db.getRecordset(Token, recordFilter={
|
||||
old_tokens = self.db.getRecordset(
|
||||
Token,
|
||||
recordFilter={
|
||||
"userId": self.currentUser.id,
|
||||
"authority": token.authority,
|
||||
"connectionId": None # Ensure we only delete access tokens
|
||||
})
|
||||
"connectionId": None, # Ensure we only delete access tokens
|
||||
},
|
||||
)
|
||||
deleted_count = 0
|
||||
for old_token in old_tokens:
|
||||
if old_token["id"] != token.id: # Don't delete the new token if it already exists
|
||||
if (
|
||||
old_token["id"] != token.id
|
||||
): # Don't delete the new token if it already exists
|
||||
self.db.recordDelete(Token, old_token["id"])
|
||||
deleted_count += 1
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}")
|
||||
logger.info(
|
||||
f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}")
|
||||
logger.warning(
|
||||
f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}"
|
||||
)
|
||||
# Continue with saving the new token even if deletion fails
|
||||
|
||||
# Convert to dict and ensure all fields are properly set
|
||||
token_dict = token.dict()
|
||||
token_dict = token.model_dump()
|
||||
# Ensure userId is set to current user
|
||||
# Convert to dict and ensure all fields are properly set
|
||||
token_dict = token.model_dump()
|
||||
# Ensure userId is set to current user
|
||||
token_dict["userId"] = self.currentUser.id
|
||||
|
||||
# Save to database
|
||||
self.db.recordCreate(Token, token_dict)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving access token: {str(e)}")
|
||||
raise
|
||||
|
|
@ -781,7 +812,9 @@ class AppObjects:
|
|||
try:
|
||||
# Validate that this IS a connection token
|
||||
if not token.connectionId:
|
||||
raise ValueError("Connection tokens must have connectionId - use saveAccessToken instead")
|
||||
raise ValueError(
|
||||
"Connection tokens must have connectionId - use saveAccessToken instead"
|
||||
)
|
||||
|
||||
# Validate user context
|
||||
if not self.currentUser or not self.currentUser.id:
|
||||
|
|
@ -799,31 +832,36 @@ class AppObjects:
|
|||
# If replace_existing is True, delete old tokens for this connectionId first
|
||||
if replace_existing:
|
||||
try:
|
||||
old_tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"connectionId": token.connectionId
|
||||
})
|
||||
old_tokens = self.db.getRecordset(
|
||||
Token, recordFilter={"connectionId": token.connectionId}
|
||||
)
|
||||
deleted_count = 0
|
||||
for old_token in old_tokens:
|
||||
if old_token["id"] != token.id: # Don't delete the new token if it already exists
|
||||
if (
|
||||
old_token["id"] != token.id
|
||||
): # Don't delete the new token if it already exists
|
||||
self.db.recordDelete(Token, old_token["id"])
|
||||
deleted_count += 1
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"Replaced {deleted_count} old tokens for connectionId {token.connectionId}")
|
||||
logger.info(
|
||||
f"Replaced {deleted_count} old tokens for connectionId {token.connectionId}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete old tokens for connectionId {token.connectionId}: {str(e)}")
|
||||
logger.warning(
|
||||
f"Failed to delete old tokens for connectionId {token.connectionId}: {str(e)}"
|
||||
)
|
||||
# Continue with saving the new token even if deletion fails
|
||||
|
||||
# Convert to dict and ensure all fields are properly set
|
||||
token_dict = token.dict()
|
||||
token_dict = token.model_dump()
|
||||
# Ensure userId is set to current user
|
||||
token_dict["userId"] = self.currentUser.id
|
||||
|
||||
# Save to database
|
||||
self.db.recordCreate(Token, token_dict)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving connection token: {str(e)}")
|
||||
raise
|
||||
|
|
@ -837,13 +875,14 @@ class AppObjects:
|
|||
|
||||
# Get token for this specific connection
|
||||
# Query for specific connection
|
||||
tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"connectionId": connectionId
|
||||
})
|
||||
|
||||
tokens = self.db.getRecordset(
|
||||
Token, recordFilter={"connectionId": connectionId}
|
||||
)
|
||||
|
||||
if not tokens:
|
||||
logger.warning(f"No connection token found for connectionId: {connectionId}")
|
||||
logger.warning(
|
||||
f"No connection token found for connectionId: {connectionId}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Sort by expiration date and get the latest (most recent expiration)
|
||||
|
|
@ -855,16 +894,27 @@ class AppObjects:
|
|||
return latest_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting connection token for connectionId {connectionId}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error getting connection token for connectionId {connectionId}: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
def findActiveTokenById(self, tokenId: str, userId: str, authority: AuthAuthority, sessionId: str = None, mandateId: str = None) -> Optional[Token]:
|
||||
def findActiveTokenById(
|
||||
self,
|
||||
tokenId: str,
|
||||
userId: str,
|
||||
authority: AuthAuthority,
|
||||
sessionId: str = None,
|
||||
mandateId: str = None,
|
||||
) -> Optional[Token]:
|
||||
"""Find an active access token by its id (jti) with optional session/tenant scoping."""
|
||||
try:
|
||||
recordFilter = {
|
||||
"id": tokenId,
|
||||
"userId": userId,
|
||||
"authority": authority.value if hasattr(authority, 'value') else str(authority),
|
||||
"authority": authority.value
|
||||
if hasattr(authority, "value")
|
||||
else str(authority),
|
||||
"status": TokenStatus.ACTIVE,
|
||||
}
|
||||
if sessionId is not None:
|
||||
|
|
@ -892,7 +942,7 @@ class AppObjects:
|
|||
"status": TokenStatus.REVOKED,
|
||||
"revokedAt": get_utc_timestamp(),
|
||||
"revokedBy": revokedBy,
|
||||
"reason": reason or "revoked"
|
||||
"reason": reason or "revoked",
|
||||
}
|
||||
self.db.recordModify(Token, tokenId, tokenUpdate)
|
||||
return True
|
||||
|
|
@ -900,30 +950,53 @@ class AppObjects:
|
|||
logger.error(f"Error revoking token {tokenId}: {str(e)}")
|
||||
return False
|
||||
|
||||
def revokeTokensBySessionId(self, sessionId: str, userId: str, authority: AuthAuthority, revokedBy: str, reason: str = None) -> int:
|
||||
def revokeTokensBySessionId(
|
||||
self,
|
||||
sessionId: str,
|
||||
userId: str,
|
||||
authority: AuthAuthority,
|
||||
revokedBy: str,
|
||||
reason: str = None,
|
||||
) -> int:
|
||||
"""Revoke all tokens of a session for a user/authority."""
|
||||
try:
|
||||
tokens = self.db.getRecordset(Token, recordFilter={
|
||||
tokens = self.db.getRecordset(
|
||||
Token,
|
||||
recordFilter={
|
||||
"userId": userId,
|
||||
"authority": authority.value if hasattr(authority, 'value') else str(authority),
|
||||
"authority": authority.value
|
||||
if hasattr(authority, "value")
|
||||
else str(authority),
|
||||
"sessionId": sessionId,
|
||||
"status": TokenStatus.ACTIVE
|
||||
})
|
||||
"status": TokenStatus.ACTIVE,
|
||||
},
|
||||
)
|
||||
count = 0
|
||||
for t in tokens:
|
||||
self.db.recordModify(Token, t["id"], {
|
||||
self.db.recordModify(
|
||||
Token,
|
||||
t["id"],
|
||||
{
|
||||
"status": TokenStatus.REVOKED,
|
||||
"revokedAt": get_utc_timestamp(),
|
||||
"revokedBy": revokedBy,
|
||||
"reason": reason or "session logout"
|
||||
})
|
||||
"reason": reason or "session logout",
|
||||
},
|
||||
)
|
||||
count += 1
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking tokens for session {sessionId}: {str(e)}")
|
||||
return 0
|
||||
|
||||
def revokeTokensByUser(self, userId: str, authority: AuthAuthority = None, mandateId: str = None, revokedBy: str = None, reason: str = None) -> int:
|
||||
def revokeTokensByUser(
|
||||
self,
|
||||
userId: str,
|
||||
authority: AuthAuthority = None,
|
||||
mandateId: str = None,
|
||||
revokedBy: str = None,
|
||||
reason: str = None,
|
||||
) -> int:
|
||||
"""Revoke all active tokens for a user, optionally filtered by authority/mandate."""
|
||||
try:
|
||||
# Fetch all active tokens for user (optionally filtered by authority)
|
||||
|
|
@ -932,16 +1005,22 @@ class AppObjects:
|
|||
"status": TokenStatus.ACTIVE,
|
||||
}
|
||||
if authority is not None:
|
||||
recordFilter["authority"] = authority.value if hasattr(authority, 'value') else str(authority)
|
||||
recordFilter["authority"] = (
|
||||
authority.value if hasattr(authority, "value") else str(authority)
|
||||
)
|
||||
tokens = self.db.getRecordset(Token, recordFilter=recordFilter)
|
||||
count = 0
|
||||
for t in tokens:
|
||||
self.db.recordModify(Token, t["id"], {
|
||||
self.db.recordModify(
|
||||
Token,
|
||||
t["id"],
|
||||
{
|
||||
"status": TokenStatus.REVOKED,
|
||||
"revokedAt": get_utc_timestamp(),
|
||||
"revokedBy": revokedBy,
|
||||
"reason": reason or "admin revoke"
|
||||
})
|
||||
"reason": reason or "admin revoke",
|
||||
},
|
||||
)
|
||||
count += 1
|
||||
return count
|
||||
except Exception as e:
|
||||
|
|
@ -958,7 +1037,10 @@ class AppObjects:
|
|||
all_tokens = self.db.getRecordset(Token, recordFilter={})
|
||||
|
||||
for token_data in all_tokens:
|
||||
if token_data.get("expiresAt") and token_data.get("expiresAt") < current_time:
|
||||
if (
|
||||
token_data.get("expiresAt")
|
||||
and token_data.get("expiresAt") < current_time
|
||||
):
|
||||
# Token is expired, delete it
|
||||
self.db.recordDelete(Token, token_data["id"])
|
||||
cleaned_count += 1
|
||||
|
|
@ -983,7 +1065,7 @@ class AppObjects:
|
|||
self.access = None
|
||||
|
||||
# Clear database context
|
||||
if hasattr(self, 'db'):
|
||||
if hasattr(self, "db"):
|
||||
self.db.updateContext("")
|
||||
|
||||
logger.info("User logged out successfully")
|
||||
|
|
@ -997,7 +1079,9 @@ class AppObjects:
|
|||
def getNeutralizationConfig(self) -> Optional[DataNeutraliserConfig]:
|
||||
"""Get the data neutralization configuration for the current user's mandate"""
|
||||
try:
|
||||
configs = self.db.getRecordset(DataNeutraliserConfig, recordFilter={"mandateId": self.mandateId})
|
||||
configs = self.db.getRecordset(
|
||||
DataNeutraliserConfig, recordFilter={"mandateId": self.mandateId}
|
||||
)
|
||||
if not configs:
|
||||
return None
|
||||
|
||||
|
|
@ -1012,7 +1096,9 @@ class AppObjects:
|
|||
logger.error(f"Error getting neutralization config: {str(e)}")
|
||||
return None
|
||||
|
||||
def createOrUpdateNeutralizationConfig(self, config_data: Dict[str, Any]) -> DataNeutraliserConfig:
|
||||
def createOrUpdateNeutralizationConfig(
|
||||
self, config_data: Dict[str, Any]
|
||||
) -> DataNeutraliserConfig:
|
||||
"""Create or update the data neutralization configuration"""
|
||||
try:
|
||||
# Check if config already exists
|
||||
|
|
@ -1025,7 +1111,9 @@ class AppObjects:
|
|||
update_data["updatedAt"] = get_utc_timestamp()
|
||||
|
||||
updated_config = DataNeutraliserConfig.from_dict(update_data)
|
||||
self.db.recordModify(DataNeutraliserConfig, existing_config.id, updated_config)
|
||||
self.db.recordModify(
|
||||
DataNeutraliserConfig, existing_config.id, updated_config
|
||||
)
|
||||
|
||||
return updated_config
|
||||
else:
|
||||
|
|
@ -1042,17 +1130,24 @@ class AppObjects:
|
|||
logger.error(f"Error creating/updating neutralization config: {str(e)}")
|
||||
raise ValueError(f"Failed to create/update neutralization config: {str(e)}")
|
||||
|
||||
def getNeutralizationAttributes(self, file_id: Optional[str] = None) -> List[DataNeutralizerAttributes]:
|
||||
def getNeutralizationAttributes(
|
||||
self, file_id: Optional[str] = None
|
||||
) -> List[DataNeutralizerAttributes]:
|
||||
"""Get neutralization attributes, optionally filtered by file ID"""
|
||||
try:
|
||||
filter_dict = {"mandateId": self.mandateId}
|
||||
if file_id:
|
||||
filter_dict["fileId"] = file_id
|
||||
|
||||
attributes = self.db.getRecordset(DataNeutralizerAttributes, recordFilter=filter_dict)
|
||||
attributes = self.db.getRecordset(
|
||||
DataNeutralizerAttributes, recordFilter=filter_dict
|
||||
)
|
||||
filtered_attributes = self._uam(DataNeutralizerAttributes, attributes)
|
||||
|
||||
return [DataNeutralizerAttributes.from_dict(attr) for attr in filtered_attributes]
|
||||
return [
|
||||
DataNeutralizerAttributes.from_dict(attr)
|
||||
for attr in filtered_attributes
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting neutralization attributes: {str(e)}")
|
||||
|
|
@ -1061,23 +1156,27 @@ class AppObjects:
|
|||
def deleteNeutralizationAttributes(self, file_id: str) -> bool:
|
||||
"""Delete all neutralization attributes for a specific file"""
|
||||
try:
|
||||
attributes = self.db.getRecordset(DataNeutralizerAttributes, recordFilter={
|
||||
"mandateId": self.mandateId,
|
||||
"fileId": file_id
|
||||
})
|
||||
attributes = self.db.getRecordset(
|
||||
DataNeutralizerAttributes,
|
||||
recordFilter={"mandateId": self.mandateId, "fileId": file_id},
|
||||
)
|
||||
|
||||
for attribute in attributes:
|
||||
self.db.recordDelete(DataNeutralizerAttributes, attribute["id"])
|
||||
|
||||
logger.info(f"Deleted {len(attributes)} neutralization attributes for file {file_id}")
|
||||
logger.info(
|
||||
f"Deleted {len(attributes)} neutralization attributes for file {file_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting neutralization attributes: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# Public Methods
|
||||
|
||||
|
||||
def getInterface(currentUser: User) -> AppObjects:
|
||||
"""
|
||||
Returns a AppObjects instance for the current user.
|
||||
|
|
@ -1095,6 +1194,7 @@ def getInterface(currentUser: User) -> AppObjects:
|
|||
|
||||
return _gatewayInterfaces[contextKey]
|
||||
|
||||
|
||||
def getRootInterface() -> AppObjects:
|
||||
"""
|
||||
Returns a AppObjects instance with root privileges.
|
||||
|
|
@ -1112,13 +1212,15 @@ def getRootInterface() -> AppObjects:
|
|||
if not initialUserId:
|
||||
raise ValueError("No initial user ID found in database")
|
||||
|
||||
users = tempInterface.db.getRecordset(UserInDB, recordFilter={"id": initialUserId})
|
||||
users = tempInterface.db.getRecordset(
|
||||
UserInDB, recordFilter={"id": initialUserId}
|
||||
)
|
||||
if not users:
|
||||
raise ValueError("Initial user not found in database")
|
||||
|
||||
# Convert to User model
|
||||
user_data = users[0]
|
||||
rootUser = User.parse_obj(user_data)
|
||||
rootUser = User.model_validate(user_data)
|
||||
|
||||
# Create root interface with the root user
|
||||
_rootAppObjects = AppObjects(rootUser)
|
||||
|
|
|
|||
|
|
@ -21,13 +21,14 @@ os.makedirs(staticFolder, exist_ok=True)
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="",
|
||||
tags=["Administration"],
|
||||
responses={404: {"description": "Not found"}}
|
||||
prefix="", tags=["Administration"], responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
# Mount static files
|
||||
router.mount("/static", StaticFiles(directory=str(staticFolder), html=True), name="static")
|
||||
router.mount(
|
||||
"/static", StaticFiles(directory=str(staticFolder), html=True), name="static"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@limiter.limit("30/minute")
|
||||
|
|
@ -36,22 +37,29 @@ async def root(request: Request) -> Dict[str, str]:
|
|||
# Validate required configuration values
|
||||
allowedOrigins = APP_CONFIG.get("APP_ALLOWED_ORIGINS")
|
||||
if not allowedOrigins:
|
||||
raise HTTPException(status_code=500, detail="APP_ALLOWED_ORIGINS configuration is required")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="APP_ALLOWED_ORIGINS configuration is required"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "online",
|
||||
"message": "Data Platform API is active",
|
||||
"allowedOrigins": f"Allowed origins are {allowedOrigins}"
|
||||
"allowedOrigins": f"Allowed origins are {allowedOrigins}",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/environment")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_environment(request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, str]:
|
||||
async def get_environment(
|
||||
request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
) -> Dict[str, str]:
|
||||
"""Get environment configuration for frontend"""
|
||||
# Validate required configuration values
|
||||
apiBaseUrl = APP_CONFIG.get("APP_API_URL")
|
||||
if not apiBaseUrl:
|
||||
raise HTTPException(status_code=500, detail="APP_API_URL configuration is required")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="APP_API_URL configuration is required"
|
||||
)
|
||||
|
||||
environment = APP_CONFIG.get("APP_ENV")
|
||||
if not environment:
|
||||
|
|
@ -59,7 +67,9 @@ async def get_environment(request: Request, currentUser: Dict[str, Any] = Depend
|
|||
|
||||
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
|
||||
if not instanceLabel:
|
||||
raise HTTPException(status_code=500, detail="APP_ENV_LABEL configuration is required")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="APP_ENV_LABEL configuration is required"
|
||||
)
|
||||
|
||||
return {
|
||||
"apiBaseUrl": apiBaseUrl,
|
||||
|
|
@ -68,13 +78,17 @@ async def get_environment(request: Request, currentUser: Dict[str, Any] = Depend
|
|||
# Add other environment variables the frontend might need
|
||||
}
|
||||
|
||||
|
||||
@router.options("/{fullPath:path}")
|
||||
@limiter.limit("60/minute")
|
||||
async def options_route(request: Request, fullPath: str) -> Response:
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/favicon.ico")
|
||||
@limiter.limit("30/minute")
|
||||
async def favicon(request: Request) -> FileResponse:
|
||||
return FileResponse(str(staticFolder / "favicon.ico"), media_type="image/x-icon")
|
||||
|
||||
favicon_path = staticFolder / "favicon.ico"
|
||||
if not favicon_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Favicon not found")
|
||||
return FileResponse(str(favicon_path), media_type="image/x-icon")
|
||||
|
|
|
|||
655
modules/routes/routeChatbot.py
Normal file
655
modules/routes/routeChatbot.py
Normal file
|
|
@ -0,0 +1,655 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
from modules.features.chatBot.database import get_async_db_session
|
||||
from modules.features.chatBot.service import (
|
||||
get_or_create_thread_for_user,
|
||||
)
|
||||
from modules.datamodels.datamodelUam import User, UserPrivilege
|
||||
from modules.datamodels.datamodelChatbot import (
|
||||
ChatMessageRequest,
|
||||
MessageItem,
|
||||
ChatMessageResponse,
|
||||
ThreadSummary,
|
||||
ThreadListResponse,
|
||||
ThreadDetail,
|
||||
RenameThreadRequest,
|
||||
DeleteResponse,
|
||||
ToolListResponse,
|
||||
ToolInfo,
|
||||
GrantToolRequest,
|
||||
GrantToolResponse,
|
||||
RevokeToolRequest,
|
||||
RevokeToolResponse,
|
||||
UpdateToolRequest,
|
||||
UpdateToolResponse,
|
||||
)
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
from modules.features.chatBot import service as chat_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/chatbot",
|
||||
tags=["Chatbot"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
# --- Actual endpoints for chatbot ---
|
||||
|
||||
|
||||
@router.post("/message/stream")
|
||||
@limiter.limit("30/minute")
|
||||
async def post_chat_message_stream(
|
||||
*,
|
||||
request: Request,
|
||||
message_request: ChatMessageRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Post a message to a chat thread with streaming progress updates.
|
||||
Creates a new thread if thread_id is not provided.
|
||||
|
||||
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
||||
"""
|
||||
try:
|
||||
# Validate and get tools for the request
|
||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||
user_id=currentUser.id,
|
||||
requested_tool_ids=message_request.tools,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Get or create thread using helper function
|
||||
thread_id = await get_or_create_thread_for_user(
|
||||
thread_id=message_request.thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
thread_name=message_request.message[:100],
|
||||
refresh_date_updated=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} posted streaming message to thread {thread_id}"
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
chat_service.post_message_stream(
|
||||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
tool_ids=tool_ids,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/message", response_model=ChatMessageResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def post_chat_message(
|
||||
*,
|
||||
request: Request,
|
||||
message_request: ChatMessageRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ChatMessageResponse:
|
||||
"""
|
||||
Post a message to a chat thread and get assistant response (non-streaming).
|
||||
Creates a new thread if thread_id is not provided.
|
||||
|
||||
For streaming updates, use the /message/stream endpoint instead.
|
||||
"""
|
||||
try:
|
||||
# Validate and get tools for the request
|
||||
tool_ids = await chat_service.validate_and_get_tools_for_request(
|
||||
user_id=currentUser.id,
|
||||
requested_tool_ids=message_request.tools,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Get or create thread using helper function
|
||||
thread_id = await get_or_create_thread_for_user(
|
||||
thread_id=message_request.thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
thread_name=message_request.message[:100],
|
||||
refresh_date_updated=True,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
||||
|
||||
response = await chat_service.post_message(
|
||||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
tool_ids=tool_ids,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error posting chat message: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to post message: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads", response_model=ThreadListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_all_threads(
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ThreadListResponse:
|
||||
"""
|
||||
Get all chat threads for the current user.
|
||||
"""
|
||||
try:
|
||||
# Get all threads for the current user
|
||||
threads = await chat_service.get_all_threads_for_user(
|
||||
user=currentUser, session=session
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved {len(threads)} threads")
|
||||
|
||||
return ThreadListResponse(threads=threads)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving threads: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve threads: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}", response_model=ThreadDetail)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_thread_by_id(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ThreadDetail:
|
||||
"""
|
||||
Get a specific chat thread with all its messages from LangGraph checkpointer.
|
||||
"""
|
||||
try:
|
||||
thread_detail = await chat_service.get_thread_detail_for_user(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved thread {thread_id}")
|
||||
return thread_detail
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found",
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/threads/{thread_id}", response_model=DeleteResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def rename_thread(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
rename_request: RenameThreadRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> DeleteResponse:
|
||||
"""
|
||||
Rename a chat thread.
|
||||
"""
|
||||
try:
|
||||
await chat_service.update_thread_name(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
new_thread_name=rename_request.new_name,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} renamed thread {thread_id} to '{rename_request.new_name}'"
|
||||
)
|
||||
|
||||
return DeleteResponse(
|
||||
message=f"Thread {thread_id} successfully renamed",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found or permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error renaming thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to rename thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/threads/{thread_id}", response_model=DeleteResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_thread(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> DeleteResponse:
|
||||
"""
|
||||
Delete a chat thread and all its associated data from both LangGraph and database.
|
||||
"""
|
||||
try:
|
||||
await chat_service.delete_thread_for_user(
|
||||
thread_id=thread_id,
|
||||
user=currentUser,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"User {currentUser.id} deleted thread {thread_id}")
|
||||
|
||||
return DeleteResponse(
|
||||
message=f"Thread {thread_id} successfully deleted",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Thread not found or permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e) or "Thread not found or permission denied",
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.error(f"Permission denied: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error deleting thread {thread_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete thread: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
# Tool Management Endpoints
|
||||
|
||||
|
||||
@router.get("/tools", response_model=ToolListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_all_tools(
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ToolListResponse:
|
||||
"""
|
||||
Get all available chatbot tools.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can view tools",
|
||||
)
|
||||
|
||||
# Get all tools from service
|
||||
tools_data = await chat_service.get_all_tools(session=session)
|
||||
|
||||
# Convert to ToolInfo objects
|
||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
||||
|
||||
logger.info(f"User {currentUser.id} retrieved {len(tools)} tools")
|
||||
|
||||
return ToolListResponse(tools=tools)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving tools: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tools/grant", response_model=GrantToolResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def grant_tool_to_user(
|
||||
*,
|
||||
request: Request,
|
||||
grant_request: GrantToolRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> GrantToolResponse:
|
||||
"""
|
||||
Grant a tool to a user.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can grant tools",
|
||||
)
|
||||
|
||||
# Grant the tool
|
||||
await chat_service.grant_tool_to_user(
|
||||
user_id=grant_request.user_id,
|
||||
tool_id=grant_request.tool_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} granted tool {grant_request.tool_id} to user {grant_request.user_id}"
|
||||
)
|
||||
|
||||
return GrantToolResponse(
|
||||
message=f"Tool successfully granted to user {grant_request.user_id}",
|
||||
user_id=grant_request.user_id,
|
||||
tool_id=grant_request.tool_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e) or "Invalid request",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error granting tool: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to grant tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/tools/revoke", response_model=RevokeToolResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def revoke_tool_from_user(
|
||||
*,
|
||||
request: Request,
|
||||
revoke_request: RevokeToolRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> RevokeToolResponse:
|
||||
"""
|
||||
Revoke a tool from a user.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can revoke tools",
|
||||
)
|
||||
|
||||
# Revoke the tool
|
||||
await chat_service.revoke_tool_from_user(
|
||||
user_id=revoke_request.user_id,
|
||||
tool_id=revoke_request.tool_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} revoked tool {revoke_request.tool_id} from user {revoke_request.user_id}"
|
||||
)
|
||||
|
||||
return RevokeToolResponse(
|
||||
message=f"Tool successfully revoked from user {revoke_request.user_id}",
|
||||
user_id=revoke_request.user_id,
|
||||
tool_id=revoke_request.tool_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e) or "Invalid request",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error revoking tool: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to revoke tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/tools/{tool_id}", response_model=UpdateToolResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_tool(
|
||||
*,
|
||||
request: Request,
|
||||
tool_id: str,
|
||||
update_request: UpdateToolRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> UpdateToolResponse:
|
||||
"""
|
||||
Update a tool's label and/or description.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can update tools",
|
||||
)
|
||||
|
||||
# Update the tool
|
||||
updated_fields = await chat_service.update_tool(
|
||||
tool_id=tool_id,
|
||||
label=update_request.label,
|
||||
description=update_request.description,
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} updated tool {tool_id}, fields: {updated_fields}"
|
||||
)
|
||||
|
||||
return UpdateToolResponse(
|
||||
message="Tool successfully updated",
|
||||
tool_id=tool_id,
|
||||
updated_fields=updated_fields,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e) or "Invalid request",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating tool: {type(e).__name__}: {str(e)}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to update tool: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools/user/{user_id}", response_model=ToolListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_tools_for_specific_user(
|
||||
*,
|
||||
request: Request,
|
||||
user_id: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ToolListResponse:
|
||||
"""
|
||||
Get all tools granted to a specific user.
|
||||
Only accessible to system administrators.
|
||||
"""
|
||||
try:
|
||||
# Check SYSADMIN permission
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only system administrators can view user tools",
|
||||
)
|
||||
|
||||
# Get tools for the specified user
|
||||
tools_data = await chat_service.get_tools_for_user(
|
||||
user_id=user_id, session=session
|
||||
)
|
||||
|
||||
# Convert to ToolInfo objects
|
||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} retrieved {len(tools)} tools for user {user_id}"
|
||||
)
|
||||
|
||||
return ToolListResponse(tools=tools)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving tools for user {user_id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve tools for user: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools/me", response_model=ToolListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_my_tools(
|
||||
*,
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
session: AsyncSession = Depends(get_async_db_session),
|
||||
) -> ToolListResponse:
|
||||
"""
|
||||
Get all tools the current user has access to.
|
||||
"""
|
||||
try:
|
||||
# Get tools for the current user
|
||||
tools_data = await chat_service.get_tools_for_user(
|
||||
user_id=currentUser.id, session=session
|
||||
)
|
||||
|
||||
# Convert to ToolInfo objects
|
||||
tools = [ToolInfo(**tool) for tool in tools_data]
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} retrieved {len(tools)} tools for themselves"
|
||||
)
|
||||
|
||||
return ToolListResponse(tools=tools)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving tools for user {currentUser.id}: {type(e).__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve your tools: {type(e).__name__}: {str(e) or 'No error message provided'}",
|
||||
)
|
||||
|
|
@ -2,13 +2,14 @@
|
|||
Shared utilities for model attributes and labels.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Dict, Any, List, Type, Optional, Union
|
||||
import inspect
|
||||
import importlib
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ModelMixin:
|
||||
"""Mixin class that provides serialization methods for Pydantic models."""
|
||||
|
||||
|
|
@ -22,7 +23,7 @@ class ModelMixin:
|
|||
Dict[str, Any]: Dictionary representation of the model
|
||||
"""
|
||||
# Get the raw dictionary
|
||||
if hasattr(self, 'model_dump'):
|
||||
if hasattr(self, "model_dump"):
|
||||
data: Dict[str, Any] = self.model_dump() # Pydantic v2
|
||||
else:
|
||||
data: Dict[str, Any] = self.dict() # Pydantic v1
|
||||
|
|
@ -33,7 +34,7 @@ class ModelMixin:
|
|||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ModelMixin':
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ModelMixin":
|
||||
"""
|
||||
Create a Pydantic model instance from a dictionary.
|
||||
|
||||
|
|
@ -45,9 +46,11 @@ class ModelMixin:
|
|||
"""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
# Define the AttributeDefinition class here instead of importing it
|
||||
class AttributeDefinition(BaseModel, ModelMixin):
|
||||
"""Definition of a model attribute with its metadata."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
label: str
|
||||
|
|
@ -64,9 +67,11 @@ class AttributeDefinition(BaseModel, ModelMixin):
|
|||
order: int = 0
|
||||
placeholder: Optional[str] = None
|
||||
|
||||
|
||||
# Global registry for model labels
|
||||
MODEL_LABELS: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
|
||||
|
||||
def to_dict(model: BaseModel) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
|
|
@ -78,10 +83,11 @@ def to_dict(model: BaseModel) -> Dict[str, Any]:
|
|||
Returns:
|
||||
Dict[str, Any]: Dictionary representation of the model
|
||||
"""
|
||||
if hasattr(model, 'model_dump'):
|
||||
if hasattr(model, "model_dump"):
|
||||
return model.model_dump() # Pydantic v2
|
||||
return model.dict() # Pydantic v1
|
||||
|
||||
|
||||
def from_dict(model_class: Type[BaseModel], data: Dict[str, Any]) -> BaseModel:
|
||||
"""
|
||||
Create a Pydantic model instance from a dictionary.
|
||||
|
|
@ -95,7 +101,10 @@ def from_dict(model_class: Type[BaseModel], data: Dict[str, Any]) -> BaseModel:
|
|||
"""
|
||||
return model_class(**data)
|
||||
|
||||
def register_model_labels(model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]]):
|
||||
|
||||
def register_model_labels(
|
||||
model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]]
|
||||
):
|
||||
"""
|
||||
Register labels for a model's attributes and the model itself.
|
||||
|
||||
|
|
@ -106,10 +115,8 @@ def register_model_labels(model_name: str, model_label: Dict[str, str], labels:
|
|||
labels: Dictionary mapping attribute names to their translations
|
||||
e.g. {"name": {"en": "Name", "fr": "Nom"}}
|
||||
"""
|
||||
MODEL_LABELS[model_name] = {
|
||||
"model": model_label,
|
||||
"attributes": labels
|
||||
}
|
||||
MODEL_LABELS[model_name] = {"model": model_label, "attributes": labels}
|
||||
|
||||
|
||||
def get_model_labels(model_name: str, language: str = "en") -> Dict[str, str]:
|
||||
"""
|
||||
|
|
@ -130,6 +137,7 @@ def get_model_labels(model_name: str, language: str = "en") -> Dict[str, str]:
|
|||
for attr, translations in attribute_labels.items()
|
||||
}
|
||||
|
||||
|
||||
def get_model_label(model_name: str, language: str = "en") -> str:
|
||||
"""
|
||||
Get the label for a model in the specified language.
|
||||
|
|
@ -145,7 +153,10 @@ def get_model_label(model_name: str, language: str = "en") -> str:
|
|||
model_label = model_data.get("model", {})
|
||||
return model_label.get(language, model_label.get("en", model_name))
|
||||
|
||||
def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguage: str = "en") -> Dict[str, Any]:
|
||||
|
||||
def getModelAttributeDefinitions(
|
||||
modelClass: Type[BaseModel] = None, userLanguage: str = "en"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get attribute definitions for a model class.
|
||||
|
||||
|
|
@ -165,11 +176,11 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
|
|||
model_label = get_model_label(model_name, userLanguage)
|
||||
|
||||
# Handle both Pydantic v1 and v2
|
||||
if hasattr(modelClass, 'model_fields'): # Pydantic v2
|
||||
if hasattr(modelClass, "model_fields"): # Pydantic v2
|
||||
fields = modelClass.model_fields
|
||||
for name, field in fields.items():
|
||||
# Extract frontend metadata from field info
|
||||
field_info = field.field_info if hasattr(field, 'field_info') else None
|
||||
field_info = field.field_info if hasattr(field, "field_info") else None
|
||||
# Check both direct attributes and extra field for frontend metadata
|
||||
frontend_type = None
|
||||
frontend_readonly = False
|
||||
|
|
@ -178,43 +189,63 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
|
|||
|
||||
if field_info:
|
||||
# Try direct attributes first
|
||||
frontend_type = getattr(field_info, 'frontend_type', None)
|
||||
frontend_readonly = getattr(field_info, 'frontend_readonly', False)
|
||||
frontend_required = getattr(field_info, 'frontend_required', frontend_required)
|
||||
frontend_options = getattr(field_info, 'frontend_options', None)
|
||||
frontend_type = getattr(field_info, "frontend_type", None)
|
||||
frontend_readonly = getattr(field_info, "frontend_readonly", False)
|
||||
frontend_required = getattr(
|
||||
field_info, "frontend_required", frontend_required
|
||||
)
|
||||
frontend_options = getattr(field_info, "frontend_options", None)
|
||||
|
||||
# If not found, check extra field
|
||||
if hasattr(field_info, 'extra') and field_info.extra:
|
||||
if hasattr(field_info, "extra") and field_info.extra:
|
||||
if frontend_type is None:
|
||||
frontend_type = field_info.extra.get('frontend_type')
|
||||
frontend_type = field_info.extra.get("frontend_type")
|
||||
if not frontend_readonly:
|
||||
frontend_readonly = field_info.extra.get('frontend_readonly', False)
|
||||
if frontend_required == field.is_required(): # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get('frontend_required', frontend_required)
|
||||
frontend_readonly = field_info.extra.get(
|
||||
"frontend_readonly", False
|
||||
)
|
||||
if (
|
||||
frontend_required == field.is_required()
|
||||
): # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get(
|
||||
"frontend_required", frontend_required
|
||||
)
|
||||
if frontend_options is None:
|
||||
frontend_options = field_info.extra.get('frontend_options')
|
||||
frontend_options = field_info.extra.get("frontend_options")
|
||||
|
||||
# Use frontend type if available, otherwise fall back to Python type
|
||||
field_type = frontend_type if frontend_type else (field.annotation.__name__ if hasattr(field.annotation, "__name__") else str(field.annotation))
|
||||
field_type = (
|
||||
frontend_type
|
||||
if frontend_type
|
||||
else (
|
||||
field.annotation.__name__
|
||||
if hasattr(field.annotation, "__name__")
|
||||
else str(field.annotation)
|
||||
)
|
||||
)
|
||||
|
||||
attributes.append({
|
||||
attributes.append(
|
||||
{
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.description if hasattr(field, "description") else "",
|
||||
"description": field.description
|
||||
if hasattr(field, "description")
|
||||
else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options
|
||||
})
|
||||
"options": frontend_options,
|
||||
}
|
||||
)
|
||||
else: # Pydantic v1
|
||||
fields = modelClass.__fields__
|
||||
for name, field in fields.items():
|
||||
# Extract frontend metadata from field info
|
||||
field_info = field.field_info if hasattr(field, 'field_info') else None
|
||||
field_info = field.field_info if hasattr(field, "field_info") else None
|
||||
# Check both direct attributes and extra field for frontend metadata
|
||||
frontend_type = None
|
||||
frontend_readonly = False
|
||||
|
|
@ -223,43 +254,61 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
|
|||
|
||||
if field_info:
|
||||
# Try direct attributes first
|
||||
frontend_type = getattr(field_info, 'frontend_type', None)
|
||||
frontend_readonly = getattr(field_info, 'frontend_readonly', False)
|
||||
frontend_required = getattr(field_info, 'frontend_required', frontend_required)
|
||||
frontend_options = getattr(field_info, 'frontend_options', None)
|
||||
frontend_type = getattr(field_info, "frontend_type", None)
|
||||
frontend_readonly = getattr(field_info, "frontend_readonly", False)
|
||||
frontend_required = getattr(
|
||||
field_info, "frontend_required", frontend_required
|
||||
)
|
||||
frontend_options = getattr(field_info, "frontend_options", None)
|
||||
|
||||
# If not found, check extra field
|
||||
if hasattr(field_info, 'extra') and field_info.extra:
|
||||
if hasattr(field_info, "extra") and field_info.extra:
|
||||
if frontend_type is None:
|
||||
frontend_type = field_info.extra.get('frontend_type')
|
||||
frontend_type = field_info.extra.get("frontend_type")
|
||||
if not frontend_readonly:
|
||||
frontend_readonly = field_info.extra.get('frontend_readonly', False)
|
||||
if frontend_required == field.required: # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get('frontend_required', frontend_required)
|
||||
frontend_readonly = field_info.extra.get(
|
||||
"frontend_readonly", False
|
||||
)
|
||||
if (
|
||||
frontend_required == field.required
|
||||
): # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get(
|
||||
"frontend_required", frontend_required
|
||||
)
|
||||
if frontend_options is None:
|
||||
frontend_options = field_info.extra.get('frontend_options')
|
||||
frontend_options = field_info.extra.get("frontend_options")
|
||||
|
||||
# Use frontend type if available, otherwise fall back to Python type
|
||||
field_type = frontend_type if frontend_type else (field.type_.__name__ if hasattr(field.type_, "__name__") else str(field.type_))
|
||||
field_type = (
|
||||
frontend_type
|
||||
if frontend_type
|
||||
else (
|
||||
field.type_.__name__
|
||||
if hasattr(field.type_, "__name__")
|
||||
else str(field.type_)
|
||||
)
|
||||
)
|
||||
|
||||
attributes.append({
|
||||
attributes.append(
|
||||
{
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.field_info.description if hasattr(field.field_info, "description") else "",
|
||||
"description": field.field_info.description
|
||||
if hasattr(field.field_info, "description")
|
||||
else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options
|
||||
})
|
||||
|
||||
return {
|
||||
"model": model_label,
|
||||
"attributes": attributes
|
||||
"options": frontend_options,
|
||||
}
|
||||
)
|
||||
|
||||
return {"model": model_label, "attributes": attributes}
|
||||
|
||||
|
||||
def getModelClasses() -> Dict[str, Type[BaseModel]]:
|
||||
"""
|
||||
|
|
@ -271,47 +320,61 @@ def getModelClasses() -> Dict[str, Type[BaseModel]]:
|
|||
modelClasses = {}
|
||||
|
||||
# Get the interfaces directory path
|
||||
interfaces_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'interfaces')
|
||||
interfaces_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "interfaces"
|
||||
)
|
||||
|
||||
# Find all model files in interfaces directory
|
||||
for fileName in os.listdir(interfaces_dir):
|
||||
if fileName.endswith('Model.py'):
|
||||
if fileName.endswith("Model.py"):
|
||||
# Convert fileName to module name (e.g., gatewayModel.py -> gatewayModel)
|
||||
module_name = fileName[:-3]
|
||||
|
||||
# Import the module dynamically
|
||||
module = importlib.import_module(f'modules.interfaces.{module_name}')
|
||||
module = importlib.import_module(f"modules.interfaces.{module_name}")
|
||||
|
||||
# Get all classes from the module
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj != BaseModel:
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, BaseModel)
|
||||
and obj != BaseModel
|
||||
):
|
||||
modelClasses[name] = obj
|
||||
|
||||
# Also get models from datamodels directory
|
||||
datamodels_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'datamodels')
|
||||
datamodels_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "datamodels"
|
||||
)
|
||||
|
||||
# Find all model files in datamodels directory
|
||||
for fileName in os.listdir(datamodels_dir):
|
||||
if fileName.startswith('datamodel') and fileName.endswith('.py'):
|
||||
if fileName.startswith("datamodel") and fileName.endswith(".py"):
|
||||
# Convert fileName to module name (e.g., datamodelUtils.py -> datamodelUtils)
|
||||
module_name = fileName[:-3]
|
||||
|
||||
# Import the module dynamically
|
||||
module = importlib.import_module(f'modules.datamodels.{module_name}')
|
||||
module = importlib.import_module(f"modules.datamodels.{module_name}")
|
||||
|
||||
# Get all classes from the module
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj != BaseModel:
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, BaseModel)
|
||||
and obj != BaseModel
|
||||
):
|
||||
modelClasses[name] = obj
|
||||
|
||||
return modelClasses
|
||||
|
||||
|
||||
class AttributeResponse(BaseModel):
|
||||
"""Response model for entity attributes"""
|
||||
|
||||
attributes: List[AttributeDefinition]
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"attributes": [
|
||||
{
|
||||
|
|
@ -322,8 +385,9 @@ class AttributeResponse(BaseModel):
|
|||
"placeholder": "Please enter username",
|
||||
"editable": True,
|
||||
"visible": True,
|
||||
"order": 0
|
||||
"order": 0,
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
## Web Framework & API
|
||||
fastapi==0.104.1
|
||||
fastapi==0.115.0 # Upgraded for Pydantic v2 compatibility
|
||||
websockets==12.0
|
||||
uvicorn==0.23.2
|
||||
python-multipart==0.0.6
|
||||
httpx==0.25.0
|
||||
pydantic==1.10.13 # Ältere Version ohne Rust-Abhängigkeit
|
||||
httpx>=0.25.2
|
||||
pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility
|
||||
email-validator==2.0.0 # Required by Pydantic for email validation
|
||||
slowapi==0.1.8 # For rate limiting
|
||||
|
||||
|
|
@ -109,3 +109,14 @@ xyzservices>=2021.09.1
|
|||
|
||||
# PostgreSQL connector dependencies
|
||||
psycopg2-binary==2.9.9
|
||||
|
||||
## LangChain & LangGraph
|
||||
langchain==0.3.27
|
||||
langgraph==0.6.8
|
||||
langchain-core==0.3.77
|
||||
langchain-anthropic==0.3.1 # For Claude models
|
||||
psycopg[binary]==3.2.1 # For PostgreSQL async support (LangGraph checkpointer)
|
||||
psycopg-pool==3.2.1 # Connection pooling for PostgreSQL
|
||||
langgraph-checkpoint-postgres==2.0.24
|
||||
|
||||
greenlet==3.2.4
|
||||
198
tests/features/chatBot/utils/test_toolRegistry.py
Normal file
198
tests/features/chatBot/utils/test_toolRegistry.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
"""Pytest tests for the tool registry.
|
||||
|
||||
This module tests that the tool registry correctly discovers and catalogs
|
||||
all tools in the chatbotTools directory.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pytest
|
||||
from modules.features.chatBot.utils.toolRegistry import (
|
||||
ToolMetadata,
|
||||
ToolRegistry,
|
||||
get_registry,
|
||||
reinitialize_registry,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
"""Test suite for ToolRegistry class."""
|
||||
|
||||
@pytest.fixture
|
||||
def registry(self) -> ToolRegistry:
|
||||
"""Provide a fresh registry instance for each test."""
|
||||
return reinitialize_registry()
|
||||
|
||||
def test_registry_initialization(self, registry: ToolRegistry) -> None:
|
||||
"""Test that registry initializes correctly."""
|
||||
assert registry.is_initialized
|
||||
assert isinstance(registry._tools, dict)
|
||||
|
||||
def test_get_all_tools(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting all registered tools."""
|
||||
all_tools = registry.get_all_tools()
|
||||
assert isinstance(all_tools, list)
|
||||
assert len(all_tools) > 0
|
||||
assert all(isinstance(tool, ToolMetadata) for tool in all_tools)
|
||||
|
||||
# Log all discovered tools
|
||||
logger.info(f"Found {len(all_tools)} tools in registry:")
|
||||
for tool in all_tools:
|
||||
logger.info(f"\n{tool}")
|
||||
|
||||
def test_tool_metadata_structure(self, registry: ToolRegistry) -> None:
|
||||
"""Test that tool metadata has correct structure."""
|
||||
all_tools = registry.get_all_tools()
|
||||
for tool in all_tools:
|
||||
assert isinstance(tool.tool_id, str)
|
||||
assert isinstance(tool.name, str)
|
||||
assert isinstance(tool.category, str)
|
||||
assert tool.category in ["shared", "customer"]
|
||||
assert isinstance(tool.description, str)
|
||||
assert isinstance(tool.tool_instance, BaseTool)
|
||||
assert isinstance(tool.module_path, str)
|
||||
|
||||
def test_list_tool_ids(self, registry: ToolRegistry) -> None:
|
||||
"""Test listing all tool IDs."""
|
||||
tool_ids = registry.list_tool_ids()
|
||||
assert isinstance(tool_ids, list)
|
||||
assert len(tool_ids) > 0
|
||||
assert all(isinstance(tool_id, str) for tool_id in tool_ids)
|
||||
|
||||
# Check that tool IDs follow expected format
|
||||
for tool_id in tool_ids:
|
||||
assert "." in tool_id
|
||||
category, name = tool_id.split(".", 1)
|
||||
assert category in ["shared", "customer"]
|
||||
|
||||
def test_get_specific_tool(self, registry: ToolRegistry) -> None:
|
||||
"""Test retrieving a specific tool by ID."""
|
||||
# Get all tool IDs first
|
||||
tool_ids = registry.list_tool_ids()
|
||||
if tool_ids:
|
||||
# Test with first available tool
|
||||
test_tool_id = tool_ids[0]
|
||||
tool_metadata = registry.get_tool(tool_id=test_tool_id)
|
||||
|
||||
assert tool_metadata is not None
|
||||
assert isinstance(tool_metadata, ToolMetadata)
|
||||
assert tool_metadata.tool_id == test_tool_id
|
||||
|
||||
def test_get_nonexistent_tool(self, registry: ToolRegistry) -> None:
|
||||
"""Test retrieving a tool that doesn't exist."""
|
||||
tool_metadata = registry.get_tool(tool_id="nonexistent.tool")
|
||||
assert tool_metadata is None
|
||||
|
||||
def test_get_tools_by_category_shared(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting all shared tools."""
|
||||
shared_tools = registry.get_tools_by_category(category="shared")
|
||||
assert isinstance(shared_tools, list)
|
||||
assert all(tool.category == "shared" for tool in shared_tools)
|
||||
|
||||
def test_get_tools_by_category_customer(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting all customer tools."""
|
||||
customer_tools = registry.get_tools_by_category(category="customer")
|
||||
assert isinstance(customer_tools, list)
|
||||
assert all(tool.category == "customer" for tool in customer_tools)
|
||||
|
||||
def test_get_tool_instances(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting tool instances by IDs."""
|
||||
tool_ids = registry.list_tool_ids()
|
||||
if len(tool_ids) >= 2:
|
||||
# Test with first two tools
|
||||
test_ids = tool_ids[:2]
|
||||
instances = registry.get_tool_instances(tool_ids=test_ids)
|
||||
|
||||
assert isinstance(instances, list)
|
||||
assert len(instances) == 2
|
||||
assert all(isinstance(inst, BaseTool) for inst in instances)
|
||||
|
||||
def test_get_tool_instances_with_invalid_id(self, registry: ToolRegistry) -> None:
|
||||
"""Test getting tool instances with some invalid IDs."""
|
||||
tool_ids = registry.list_tool_ids()
|
||||
if tool_ids:
|
||||
# Mix valid and invalid IDs
|
||||
test_ids = [tool_ids[0], "invalid.tool"]
|
||||
instances = registry.get_tool_instances(tool_ids=test_ids)
|
||||
|
||||
# Should only return the valid one
|
||||
assert len(instances) == 1
|
||||
assert isinstance(instances[0], BaseTool)
|
||||
|
||||
def test_global_registry_singleton(self) -> None:
|
||||
"""Test that get_registry returns same instance."""
|
||||
registry1 = get_registry()
|
||||
registry2 = get_registry()
|
||||
assert registry1 is registry2
|
||||
|
||||
def test_reinitialize_registry(self) -> None:
|
||||
"""Test that reinitialize creates new instance."""
|
||||
registry1 = get_registry()
|
||||
registry2 = reinitialize_registry()
|
||||
# Should be different instances after reinitialize
|
||||
assert registry1 is not registry2
|
||||
assert registry2.is_initialized
|
||||
|
||||
|
||||
class TestToolDiscovery:
|
||||
"""Test suite for tool discovery functionality."""
|
||||
|
||||
def test_discovers_at_least_one_tool(self) -> None:
|
||||
"""Test that at least one tool is discovered."""
|
||||
registry = get_registry()
|
||||
tool_ids = registry.list_tool_ids()
|
||||
|
||||
# At least one tool should be successfully loaded
|
||||
assert len(tool_ids) >= 1, "Expected at least one tool to be discovered"
|
||||
|
||||
def test_query_althaus_database_if_available(self) -> None:
|
||||
"""Test query_althaus_database tool if it was successfully loaded."""
|
||||
registry = get_registry()
|
||||
tool = registry.get_tool(tool_id="customer.query_althaus_database")
|
||||
|
||||
if tool is not None:
|
||||
assert tool.name == "query_althaus_database"
|
||||
assert tool.category == "customer"
|
||||
assert "database" in tool.description.lower()
|
||||
else:
|
||||
# Tool may not have loaded due to import errors - log warning
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"customer.query_althaus_database tool not found - "
|
||||
"may have failed to import"
|
||||
)
|
||||
|
||||
def test_tavily_search_if_available(self) -> None:
|
||||
"""Test tavily_search tool if it was successfully loaded."""
|
||||
registry = get_registry()
|
||||
tool = registry.get_tool(tool_id="shared.tavily_search")
|
||||
|
||||
if tool is not None:
|
||||
assert tool.name == "tavily_search"
|
||||
assert tool.category == "shared"
|
||||
assert "search" in tool.description.lower()
|
||||
else:
|
||||
# Tool may not have loaded due to import errors - log warning
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"shared.tavily_search tool not found - may have failed to import"
|
||||
)
|
||||
|
||||
def test_tool_ids_have_correct_format(self) -> None:
|
||||
"""Test that all discovered tool IDs follow the expected format."""
|
||||
registry = get_registry()
|
||||
tool_ids = registry.list_tool_ids()
|
||||
|
||||
for tool_id in tool_ids:
|
||||
# All tool IDs should have format: category.toolname
|
||||
assert "." in tool_id, f"Tool ID {tool_id} missing category separator"
|
||||
category, name = tool_id.split(".", 1)
|
||||
assert category in [
|
||||
"shared",
|
||||
"customer",
|
||||
], f"Tool {tool_id} has invalid category: {category}"
|
||||
assert len(name) > 0, f"Tool {tool_id} has empty name"
|
||||
Loading…
Reference in a new issue