database attached

This commit is contained in:
ValueOn AG 2025-09-08 12:45:03 +02:00
parent 98a4323b36
commit 8fbbd35055
70 changed files with 1923 additions and 1259 deletions

7
app.py
View file

@ -63,10 +63,11 @@ def initLogging():
class EmojiFilter(logging.Filter): class EmojiFilter(logging.Filter):
def filter(self, record): def filter(self, record):
if isinstance(record.msg, str): if isinstance(record.msg, str):
# Remove emojis and other Unicode characters that might cause encoding issues # Remove only emojis, preserve other Unicode characters like quotes
import re import re
# Remove emojis and other Unicode symbols import unicodedata
record.msg = re.sub(r'[^\x00-\x7F]+', '[EMOJI]', record.msg) # Remove emoji characters specifically
record.msg = ''.join(char for char in record.msg if unicodedata.category(char) != 'So' or not (0x1F600 <= ord(char) <= 0x1F64F or 0x1F300 <= ord(char) <= 0x1F5FF or 0x1F680 <= ord(char) <= 0x1F6FF or 0x1F1E0 <= ord(char) <= 0x1F1FF or 0x2600 <= ord(char) <= 0x26FF or 0x2700 <= ord(char) <= 0x27BF))
return True return True
# Configure handlers based on config # Configure handlers based on config

69
env_dev.env Normal file
View file

@ -0,0 +1,69 @@
# Development Environment Configuration
# System Configuration
APP_ENV_TYPE = dev
APP_ENV_LABEL = Development Instance Patrick
APP_API_URL = http://localhost:8000
# Database Configuration for Application
# JSON File Storage (current)
# DB_APP_HOST=D:/Temp/_powerondb
# DB_APP_DATABASE=app
# DB_APP_USER=dev_user
# DB_APP_PASSWORD_SECRET=dev_password
# PostgreSQL Storage (new)
DB_APP_HOST=localhost
DB_APP_DATABASE=poweron_app_dev
DB_APP_USER=poweron_dev
DB_APP_PASSWORD_SECRET=dev_password
DB_APP_PORT=5432
# Database Configuration Chat
# JSON File Storage (current)
# DB_CHAT_HOST=D:/Temp/_powerondb
# DB_CHAT_DATABASE=chat
# DB_CHAT_USER=dev_user
# DB_CHAT_PASSWORD_SECRET=dev_password
# PostgreSQL Storage (new)
DB_CHAT_HOST=localhost
DB_CHAT_DATABASE=poweron_chat_dev
DB_CHAT_USER=poweron_dev
DB_CHAT_PASSWORD_SECRET=dev_password
DB_CHAT_PORT=5432
# Database Configuration Management
# JSON File Storage (current)
# DB_MANAGEMENT_HOST=D:/Temp/_powerondb
# DB_MANAGEMENT_DATABASE=management
# DB_MANAGEMENT_USER=dev_user
# DB_MANAGEMENT_PASSWORD_SECRET=dev_password
# PostgreSQL Storage (new)
DB_MANAGEMENT_HOST=localhost
DB_MANAGEMENT_DATABASE=poweron_management_dev
DB_MANAGEMENT_USER=poweron_dev
DB_MANAGEMENT_PASSWORD_SECRET=dev_password
DB_MANAGEMENT_PORT=5432
# Security Configuration
APP_JWT_SECRET_SECRET=dev_jwt_secret_token
APP_TOKEN_EXPIRY=300
# CORS Configuration
APP_ALLOWED_ORIGINS=http://localhost:8080,https://playground.poweron-center.net
# Logging configuration
APP_LOGGING_LOG_LEVEL = DEBUG
APP_LOGGING_LOG_FILE = poweron.log
APP_LOGGING_FORMAT = %(asctime)s - %(levelname)s - %(name)s - %(message)s
APP_LOGGING_DATE_FORMAT = %Y-%m-%d %H:%M:%S
APP_LOGGING_CONSOLE_ENABLED = True
APP_LOGGING_FILE_ENABLED = True
APP_LOGGING_ROTATION_SIZE = 10485760
APP_LOGGING_BACKUP_COUNT = 5
# Service Redirects
Service_MSFT_REDIRECT_URI = http://localhost:8000/api/msft/auth/callback
Service_GOOGLE_REDIRECT_URI = http://localhost:8000/api/google/auth/callback

View file

@ -6,22 +6,46 @@ APP_ENV_LABEL = Integration Instance
APP_API_URL = https://gateway-int.poweron-center.net APP_API_URL = https://gateway-int.poweron-center.net
# Database Configuration Application # Database Configuration Application
DB_APP_HOST=/home/_powerondb # JSON File Storage (current)
DB_APP_DATABASE=app # DB_APP_HOST=/home/_powerondb
DB_APP_USER=dev_user # DB_APP_DATABASE=app
DB_APP_PASSWORD_SECRET=dev_password # DB_APP_USER=dev_user
# DB_APP_PASSWORD_SECRET=dev_password
# PostgreSQL Storage (new)
DB_APP_HOST=gateway-int-db.poweron-center.net
DB_APP_DATABASE=poweron_app_int
DB_APP_USER=poweron_int
DB_APP_PASSWORD_SECRET=int_password_secure
DB_APP_PORT=5432
# Database Configuration Chat # Database Configuration Chat
DB_CHAT_HOST=/home/_powerondb # JSON File Storage (current)
DB_CHAT_DATABASE=chat # DB_CHAT_HOST=/home/_powerondb
DB_CHAT_USER=dev_user # DB_CHAT_DATABASE=chat
DB_CHAT_PASSWORD_SECRET=dev_password # DB_CHAT_USER=dev_user
# DB_CHAT_PASSWORD_SECRET=dev_password
# PostgreSQL Storage (new)
DB_CHAT_HOST=gateway-int-db.poweron-center.net
DB_CHAT_DATABASE=poweron_chat_int
DB_CHAT_USER=poweron_int
DB_CHAT_PASSWORD_SECRET=int_password_secure
DB_CHAT_PORT=5432
# Database Configuration Management # Database Configuration Management
DB_MANAGEMENT_HOST=/home/_powerondb # JSON File Storage (current)
DB_MANAGEMENT_DATABASE=management # DB_MANAGEMENT_HOST=/home/_powerondb
DB_MANAGEMENT_USER=dev_user # DB_MANAGEMENT_DATABASE=management
DB_MANAGEMENT_PASSWORD_SECRET=dev_password # DB_MANAGEMENT_USER=dev_user
# DB_MANAGEMENT_PASSWORD_SECRET=dev_password
# PostgreSQL Storage (new)
DB_MANAGEMENT_HOST=gateway-int-db.poweron-center.net
DB_MANAGEMENT_DATABASE=poweron_management_int
DB_MANAGEMENT_USER=poweron_int
DB_MANAGEMENT_PASSWORD_SECRET=int_password_secure
DB_MANAGEMENT_PORT=5432
# Security Configuration # Security Configuration
APP_JWT_SECRET_SECRET=dev_jwt_secret_token APP_JWT_SECRET_SECRET=dev_jwt_secret_token

View file

@ -6,22 +6,46 @@ APP_ENV_LABEL = Production Instance
APP_API_URL = https://gateway.poweron-center.net APP_API_URL = https://gateway.poweron-center.net
# Database Configuration Application # Database Configuration Application
DB_APP_HOST=/home/_powerondb # JSON File Storage (current)
DB_APP_DATABASE=app # DB_APP_HOST=/home/_powerondb
DB_APP_USER=dev_user # DB_APP_DATABASE=app
DB_APP_PASSWORD_SECRET=dev_password # DB_APP_USER=dev_user
# DB_APP_PASSWORD_SECRET=dev_password
# PostgreSQL Storage (new)
DB_APP_HOST=gateway-prod-server.postgres.database.azure.com
DB_APP_DATABASE=gateway-app
DB_APP_USER=gzxxmcrdhn
DB_APP_PASSWORD_SECRET=prod_password_very_secure.2025
DB_APP_PORT=5432
# Database Configuration Chat # Database Configuration Chat
DB_CHAT_HOST=/home/_powerondb # JSON File Storage (current)
DB_CHAT_DATABASE=chat # DB_CHAT_HOST=/home/_powerondb
DB_CHAT_USER=dev_user # DB_CHAT_DATABASE=chat
DB_CHAT_PASSWORD_SECRET=dev_password # DB_CHAT_USER=gzxxmcrdhn
# DB_CHAT_PASSWORD_SECRET=dev_password
# PostgreSQL Storage (new)
DB_CHAT_HOST=gateway-prod-server.postgres.database.azure.com
DB_CHAT_DATABASE=gateway-chat
DB_CHAT_USER=poweron_prod
DB_CHAT_PASSWORD_SECRET=prod_password_very_secure.2025
DB_CHAT_PORT=5432
# Database Configuration Management # Database Configuration Management
DB_MANAGEMENT_HOST=/home/_powerondb # JSON File Storage (current)
DB_MANAGEMENT_DATABASE=management # DB_MANAGEMENT_HOST=/home/_powerondb
DB_MANAGEMENT_USER=dev_user # DB_MANAGEMENT_DATABASE=gateway-management
DB_MANAGEMENT_PASSWORD_SECRET=dev_password # DB_MANAGEMENT_USER=gzxxmcrdhn
# DB_MANAGEMENT_PASSWORD_SECRET=dev_password
# PostgreSQL Storage (new)
DB_MANAGEMENT_HOST=gateway-prod-server.postgres.database.azure.com
DB_MANAGEMENT_DATABASE=gateway-management
DB_MANAGEMENT_USER=poweron_prod
DB_MANAGEMENT_PASSWORD_SECRET=prod_password_very_secure.2025
DB_MANAGEMENT_PORT=5432
# Security Configuration # Security Configuration
APP_JWT_SECRET_SECRET=dev_jwt_secret_token APP_JWT_SECRET_SECRET=dev_jwt_secret_token

View file

@ -66,7 +66,7 @@ class DocumentGenerator:
logger.error(f"Error processing single document: {str(e)}") logger.error(f"Error processing single document: {str(e)}")
return None return None
def createDocumentsFromActionResult(self, action_result, action, workflow) -> List[Any]: def createDocumentsFromActionResult(self, action_result, action, workflow, message_id=None) -> List[Any]:
""" """
Create actual document objects from action result and store them in the system. Create actual document objects from action result and store them in the system.
Returns a list of created document objects with proper workflow context. Returns a list of created document objects with proper workflow context.
@ -103,7 +103,8 @@ class DocumentGenerator:
fileName=document_name, fileName=document_name,
mimeType=mime_type, mimeType=mime_type,
content=content, content=content,
base64encoded=False base64encoded=False,
messageId=message_id
) )
if document: if document:
# Set workflow context on the document if possible # Set workflow context on the document if possible

View file

@ -250,7 +250,7 @@ class HandlingTasks:
"taskProgress": "pending" "taskProgress": "pending"
} }
message = self.chatInterface.createWorkflowMessage(message_data) message = self.chatInterface.createMessage(message_data)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
@ -492,7 +492,7 @@ class HandlingTasks:
if task_step.userMessage: if task_step.userMessage:
task_start_message["message"] += f"\n\n💬 {task_step.userMessage}" task_start_message["message"] += f"\n\n💬 {task_step.userMessage}"
message = self.chatInterface.createWorkflowMessage(task_start_message) message = self.chatInterface.createMessage(task_start_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
logger.info(f"Task start message created for task {task_index}") logger.info(f"Task start message created for task {task_index}")
@ -569,7 +569,7 @@ class HandlingTasks:
"actionNumber": action_number "actionNumber": action_number
}) })
message = self.chatInterface.createWorkflowMessage(action_start_message) message = self.chatInterface.createMessage(action_start_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
logger.info(f"Action start message created for action {action_number}") logger.info(f"Action start message created for action {action_number}")
@ -623,7 +623,7 @@ class HandlingTasks:
"taskProgress": "success" "taskProgress": "success"
} }
message = self.chatInterface.createWorkflowMessage(task_completion_message) message = self.chatInterface.createMessage(task_completion_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
logger.info(f"Task completion message created for task {task_index}") logger.info(f"Task completion message created for task {task_index}")
@ -715,7 +715,7 @@ class HandlingTasks:
"taskProgress": "retry" "taskProgress": "retry"
} }
message = self.chatInterface.createWorkflowMessage(retry_message) message = self.chatInterface.createMessage(retry_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
@ -768,7 +768,7 @@ class HandlingTasks:
} }
try: try:
message = self.chatInterface.createWorkflowMessage(message_data) message = self.chatInterface.createMessage(message_data)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
logger.info(f"Created user-facing retry message for failed task: {task_step.objective}") logger.info(f"Created user-facing retry message for failed task: {task_step.objective}")
@ -822,7 +822,7 @@ class HandlingTasks:
} }
try: try:
message = self.chatInterface.createWorkflowMessage(message_data) message = self.chatInterface.createMessage(message_data)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
logger.info(f"Created user-facing error message for failed task: {task_step.objective}") logger.info(f"Created user-facing error message for failed task: {task_step.objective}")
@ -1030,8 +1030,11 @@ class HandlingTasks:
if "execParameters" not in actionData: if "execParameters" not in actionData:
actionData["execParameters"] = {} actionData["execParameters"] = {}
# Use generic field separation based on TaskAction model
simple_fields, object_fields = self.chatInterface._separate_object_fields(TaskAction, actionData)
# Create action in database # Create action in database
createdAction = self.chatInterface.db.recordCreate("taskActions", actionData) createdAction = self.chatInterface.db.recordCreate(TaskAction, simple_fields)
# Convert to TaskAction model # Convert to TaskAction model
return TaskAction( return TaskAction(
@ -1095,27 +1098,36 @@ class HandlingTasks:
) )
result_label = action.execResultLabel result_label = action.execResultLabel
# Process documents from the action result # Process documents from the action result
created_documents = [] created_documents = []
if result.success: if result.success:
created_documents = self.documentGenerator.createDocumentsFromActionResult(result, action, workflow) action.setSuccess()
action.setSuccess() # Extract result text from documents if available, otherwise use empty string
# Extract result text from documents if available, otherwise use empty string action.result = ""
action.result = "" if result.documents and len(result.documents) > 0:
if result.documents and len(result.documents) > 0: # Try to get text content from the first document
# Try to get text content from the first document first_doc = result.documents[0]
first_doc = result.documents[0] if isinstance(first_doc.documentData, dict):
if isinstance(first_doc.documentData, dict): action.result = first_doc.documentData.get("result", "")
action.result = first_doc.documentData.get("result", "") elif isinstance(first_doc.documentData, str):
elif isinstance(first_doc.documentData, str): action.result = first_doc.documentData
action.result = first_doc.documentData # Preserve the action's execResultLabel for document routing
# Preserve the action's execResultLabel for document routing # Action methods should NOT return resultLabel - this is managed by the action handler
# Action methods should NOT return resultLabel - this is managed by the action handler if not action.execResultLabel:
if not action.execResultLabel: logger.warning(f"Action {action.execMethod}.{action.execAction} has no execResultLabel set")
logger.warning(f"Action {action.execMethod}.{action.execAction} has no execResultLabel set") # Always use the action's execResultLabel for message creation to ensure proper document routing
# Always use the action's execResultLabel for message creation to ensure proper document routing message_result_label = action.execResultLabel
message_result_label = action.execResultLabel
await self.createActionMessage(action, result, workflow, message_result_label, created_documents, task_step, task_index) # Create message first to get messageId, then create documents with messageId
message = await self.createActionMessage(action, result, workflow, message_result_label, [], task_step, task_index)
if message:
# Now create documents with the messageId
created_documents = self.documentGenerator.createDocumentsFromActionResult(result, action, workflow, message.id)
# Update the message with the created documents
if created_documents:
message.documents = created_documents
# Update the message in the database
self.chatInterface.updateMessage(message.id, {"documents": [doc.to_dict() for doc in created_documents]})
# Log action results # Log action results
logger.info(f"Action completed successfully") logger.info(f"Action completed successfully")
@ -1138,10 +1150,10 @@ class HandlingTasks:
logger.error(f"Action failed: {result.error}") logger.error(f"Action failed: {result.error}")
# ⚠️ IMPORTANT: Create error message for failed actions so user can see what went wrong # ⚠️ IMPORTANT: Create error message for failed actions so user can see what went wrong
await self.createActionMessage(action, result, workflow, result_label, [], task_step, task_index) message = await self.createActionMessage(action, result, workflow, result_label, [], task_step, task_index)
# Create database log entry for action failure # Create database log entry for action failure
self.chatInterface.createWorkflowLog({ self.chatInterface.createLog({
"workflowId": workflow.id, "workflowId": workflow.id,
"message": f"❌ **Task {task_num}**\n\n❌ **Action {action_num}/{total_actions}** failed: {result.error}", "message": f"❌ **Task {task_num}**\n\n❌ **Action {action_num}/{total_actions}** failed: {result.error}",
"type": "error" "type": "error"
@ -1237,14 +1249,17 @@ class HandlingTasks:
logger.info(f"Creating ERROR message: {message_text}") logger.info(f"Creating ERROR message: {message_text}")
logger.info(f"Message data: {message_data}") logger.info(f"Message data: {message_data}")
message = self.chatInterface.createWorkflowMessage(message_data) message = self.chatInterface.createMessage(message_data)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
logger.info(f"Message created: {action.execMethod}.{action.execAction}") logger.info(f"Message created: {action.execMethod}.{action.execAction}")
return message
else: else:
logger.error(f"Failed to create workflow message for action {action.execMethod}.{action.execAction}") logger.error(f"Failed to create workflow message for action {action.execMethod}.{action.execAction}")
return None
except Exception as e: except Exception as e:
logger.error(f"Error creating action message: {str(e)}") logger.error(f"Error creating action message: {str(e)}")
return None
# --- Helper validation methods --- # --- Helper validation methods ---

View file

@ -920,7 +920,7 @@ Please provide a comprehensive summary of this conversation."""
logger.error(f"Error during document access recovery for {document.id}: {str(e)}") logger.error(f"Error during document access recovery for {document.id}: {str(e)}")
return False return False
def createDocument(self, fileName: str, mimeType: str, content: str, base64encoded: bool = True) -> ChatDocument: def createDocument(self, fileName: str, mimeType: str, content: str, base64encoded: bool = True, messageId: str = None) -> ChatDocument:
"""Create document with file in one step - handles file creation internally""" """Create document with file in one step - handles file creation internally"""
# Convert content to bytes based on base64 flag # Convert content to bytes based on base64 flag
if base64encoded: if base64encoded:
@ -948,6 +948,7 @@ Please provide a comprehensive summary of this conversation."""
# Create document with all file attributes copied # Create document with all file attributes copied
document = ChatDocument( document = ChatDocument(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
messageId=messageId or "", # Use provided messageId or empty string as fallback
fileId=file_item.id, fileId=file_item.id,
fileName=file_info.get("fileName", fileName), fileName=file_info.get("fileName", fileName),
fileSize=file_info.get("size", 0), fileSize=file_info.get("size", 0),
@ -1060,7 +1061,7 @@ Please provide a comprehensive summary of this conversation."""
logger.error(f"Error executing method {methodName}.{actionName}: {str(e)}") logger.error(f"Error executing method {methodName}.{actionName}: {str(e)}")
raise raise
async def processFileIds(self, fileIds: List[str]) -> List[ChatDocument]: async def processFileIds(self, fileIds: List[str], messageId: str = None) -> List[ChatDocument]:
"""Process file IDs from existing files and return ChatDocument objects""" """Process file IDs from existing files and return ChatDocument objects"""
documents = [] documents = []
for fileId in fileIds: for fileId in fileIds:
@ -1071,6 +1072,7 @@ Please provide a comprehensive summary of this conversation."""
# Create document directly with all file attributes # Create document directly with all file attributes
document = ChatDocument( document = ChatDocument(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
messageId=messageId or "", # Use provided messageId or empty string as fallback
fileId=fileId, fileId=fileId,
fileName=fileInfo.get("fileName", "unknown"), fileName=fileInfo.get("fileName", "unknown"),
fileSize=fileInfo.get("size", 0), fileSize=fileInfo.get("size", 0),

View file

@ -33,9 +33,11 @@ class DatabaseConnector:
# Set userId (default to empty string if None) # Set userId (default to empty string if None)
self.userId = userId if userId is not None else "" self.userId = userId if userId is not None else ""
# Ensure the database directory exists # Initialize database system
self.initDbSystem()
# Set up database folder path
self.dbFolder = os.path.join(self.dbHost, self.dbDatabase) self.dbFolder = os.path.join(self.dbHost, self.dbDatabase)
os.makedirs(self.dbFolder, exist_ok=True)
# Cache for loaded data # Cache for loaded data
self._tablesCache: Dict[str, List[Dict[str, Any]]] = {} self._tablesCache: Dict[str, List[Dict[str, Any]]] = {}
@ -52,6 +54,17 @@ class DatabaseConnector:
logger.debug(f"Context: userId={self.userId}") logger.debug(f"Context: userId={self.userId}")
def initDbSystem(self):
"""Initialize the database system - creates necessary directories and structure."""
try:
# Ensure the database directory exists
self.dbFolder = os.path.join(self.dbHost, self.dbDatabase)
os.makedirs(self.dbFolder, exist_ok=True)
logger.info(f"Database system initialized: {self.dbFolder}")
except Exception as e:
logger.error(f"Error initializing database system: {e}")
raise
def _initializeSystemTable(self): def _initializeSystemTable(self):
"""Initializes the system table if it doesn't exist yet.""" """Initializes the system table if it doesn't exist yet."""
systemTablePath = self._getTablePath(self._systemTableName) systemTablePath = self._getTablePath(self._systemTableName)
@ -652,8 +665,14 @@ class DatabaseConnector:
except Exception as release_error: except Exception as release_error:
logger.error(f"Error releasing record lock for {recordPath}: {release_error}") logger.error(f"Error releasing record lock for {recordPath}: {release_error}")
def getInitialId(self, table: str) -> Optional[str]: def getInitialId(self, table_or_model) -> Optional[str]:
"""Returns the initial ID for a table.""" """Returns the initial ID for a table."""
# Handle both string table names (legacy) and model classes (new)
if isinstance(table_or_model, str):
table = table_or_model
else:
table = table_or_model.__name__
systemData = self._loadSystemTable() systemData = self._loadSystemTable()
initialId = systemData.get(table) initialId = systemData.get(table)
logger.debug(f"Initial ID for table '{table}': {initialId}") logger.debug(f"Initial ID for table '{table}': {initialId}")

View file

@ -0,0 +1,840 @@
import psycopg2
import psycopg2.extras
import json
import os
import logging
from typing import List, Dict, Any, Optional, Union, get_origin, get_args
from datetime import datetime
import uuid
from pydantic import BaseModel
import threading
import time
from modules.shared.attributeUtils import to_dict
from modules.shared.timezoneUtils import get_utc_timestamp
from modules.shared.configuration import APP_CONFIG
from modules.interfaces.interfaceAppModel import SystemTable
logger = logging.getLogger(__name__)
# No mapping needed - table name = Pydantic model name exactly
def _get_model_fields(model_class) -> Dict[str, str]:
"""Get all fields from Pydantic model and map to SQL types."""
if not hasattr(model_class, '__fields__'):
return {}
fields = {}
for field_name, field_info in model_class.__fields__.items():
field_type = field_info.type_
# Check for JSONB fields (Dict, List, or complex types)
if (field_type == dict or
field_type == list or
(hasattr(field_type, '__origin__') and field_type.__origin__ in (dict, list)) or
field_name in ['execParameters', 'expectedDocumentFormats', 'resultDocuments', 'logs', 'messages', 'stats', 'tasks']):
fields[field_name] = 'JSONB'
# Simple type mapping
elif field_type in (str, type(None)) or (get_origin(field_type) is Union and type(None) in get_args(field_type)):
fields[field_name] = 'TEXT'
elif field_type == int:
fields[field_name] = 'INTEGER'
elif field_type == float:
fields[field_name] = 'REAL'
elif field_type == bool:
fields[field_name] = 'BOOLEAN'
else:
fields[field_name] = 'TEXT' # Default to TEXT
return fields
# No caching needed with proper database
class DatabaseConnector:
"""
A connector for PostgreSQL-based data storage.
Provides generic database operations without user/mandate filtering.
Uses PostgreSQL with JSONB columns for flexible data storage.
"""
def __init__(self, dbHost: str, dbDatabase: str, dbUser: str = None, dbPassword: str = None, dbPort: int = None, userId: str = None):
# Store the input parameters
self.dbHost = dbHost
self.dbDatabase = dbDatabase
self.dbUser = dbUser
self.dbPassword = dbPassword
self.dbPort = dbPort
# Set userId (default to empty string if None)
self.userId = userId if userId is not None else ""
# Initialize database system first (creates database if needed)
self.connection = None
self.initDbSystem()
# No caching needed with proper database - PostgreSQL handles performance
# Thread safety
self._lock = threading.Lock()
# Initialize system table
self._systemTableName = "_system"
self._initializeSystemTable()
logger.debug(f"Context: userId={self.userId}")
def initDbSystem(self):
"""Initialize the database system - creates database and tables."""
try:
# Create database if it doesn't exist
self._create_database_if_not_exists()
# Create tables
self._create_tables()
# Establish connection to the database
self._connect()
logger.info("PostgreSQL database system initialized successfully")
except Exception as e:
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
raise
def _create_database_if_not_exists(self):
"""Create the database if it doesn't exist."""
try:
# Use the configured user for database creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
database="postgres",
user=self.dbUser,
password=self.dbPassword,
client_encoding='utf8'
)
conn.autocommit = True
with conn.cursor() as cursor:
# Check if database exists
cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,))
exists = cursor.fetchone()
if not exists:
# Create database
cursor.execute(f"CREATE DATABASE {self.dbDatabase}")
logger.info(f"Created database: {self.dbDatabase}")
else:
logger.info(f"Database {self.dbDatabase} already exists")
conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create database: {e}")
logger.error("Database connection failed - application cannot start")
raise RuntimeError(f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}")
def _create_tables(self):
"""Create only the system table - application tables are created by interfaces."""
try:
# Use the configured user for table creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
database=self.dbDatabase,
user=self.dbUser,
password=self.dbPassword,
client_encoding='utf8'
)
conn.autocommit = True
with conn.cursor() as cursor:
# Create only the system table
cursor.execute("""
CREATE TABLE IF NOT EXISTS _system (
id SERIAL PRIMARY KEY,
table_name VARCHAR(255) UNIQUE NOT NULL,
initial_id VARCHAR(255) NOT NULL,
_createdAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
_modifiedAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
logger.info("System table created successfully")
conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create system table: {e}")
logger.error("Database system table creation failed - application cannot start")
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
def _connect(self):
"""Establish connection to PostgreSQL database."""
try:
# Use configured user for main connection with proper parameter handling
self.connection = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
database=self.dbDatabase,
user=self.dbUser,
password=self.dbPassword,
client_encoding='utf8',
cursor_factory=psycopg2.extras.RealDictCursor
)
self.connection.autocommit = False # Use transactions
logger.info(f"Connected to PostgreSQL database: {self.dbDatabase}")
except Exception as e:
logger.error(f"Failed to connect to PostgreSQL: {e}")
raise
def _ensure_connection(self):
"""Ensure database connection is alive, reconnect if necessary."""
try:
if self.connection is None or self.connection.closed:
self._connect()
else:
# Test connection with a simple query
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1")
except Exception as e:
logger.warning(f"Connection lost, reconnecting: {e}")
self._connect()
def _initializeSystemTable(self):
"""Initializes the system table if it doesn't exist yet."""
try:
# First ensure the system table exists
self._ensureTableExists(SystemTable)
with self.connection.cursor() as cursor:
# Check if system table has any data
cursor.execute('SELECT COUNT(*) FROM "_system"')
row = cursor.fetchone()
count = row['count'] if row else 0
self.connection.commit()
except Exception as e:
logger.error(f"Error initializing system table: {e}")
self.connection.rollback()
raise
def _loadSystemTable(self) -> Dict[str, str]:
"""Loads the system table with the initial IDs."""
try:
with self.connection.cursor() as cursor:
cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
rows = cursor.fetchall()
system_data = {}
for row in rows:
system_data[row['table_name']] = row['initial_id']
return system_data
except Exception as e:
logger.error(f"Error loading system table: {e}")
return {}
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
"""Saves the system table with the initial IDs."""
try:
with self.connection.cursor() as cursor:
# Clear existing data
cursor.execute('DELETE FROM "_system"')
# Insert new data
for table_name, initial_id in data.items():
cursor.execute("""
INSERT INTO "_system" ("table_name", "initial_id", "_modifiedAt")
VALUES (%s, %s, CURRENT_TIMESTAMP)
""", (table_name, initial_id))
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error saving system table: {e}")
self.connection.rollback()
return False
def _ensureSystemTableExists(self) -> bool:
"""Ensures the system table exists, creates it if it doesn't."""
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
# Check if system table exists
cursor.execute("SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", (self._systemTableName,))
exists = cursor.fetchone()['count'] > 0
if not exists:
# Create system table
cursor.execute(f"""
CREATE TABLE "{self._systemTableName}" (
"table_name" VARCHAR(255) PRIMARY KEY,
"initial_id" VARCHAR(255),
"_createdAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
"_modifiedAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
logger.info("System table created successfully")
else:
# Check if we need to add missing columns to existing table
cursor.execute("""
SELECT column_name FROM information_schema.columns
WHERE table_name = %s AND table_schema = 'public'
""", (self._systemTableName,))
existing_columns = [row['column_name'] for row in cursor.fetchall()]
if '_modifiedAt' not in existing_columns:
cursor.execute(f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "_modifiedAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP')
logger.info("Added _modifiedAt column to existing system table")
logger.debug("System table already exists")
return True
except Exception as e:
logger.error(f"Error ensuring system table exists: {e}")
return False
def _ensureTableExists(self, model_class: type) -> bool:
"""Ensures a table exists, creates it if it doesn't."""
table = model_class.__name__
if table == "SystemTable":
# Handle system table specially - it uses _system as the actual table name
return self._ensureSystemTableExists()
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
# Check if table exists by querying information_schema with case-insensitive search
cursor.execute('''
SELECT COUNT(*) FROM information_schema.tables
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
''', (table,))
exists = cursor.fetchone()['count'] > 0
logger.debug(f"Table {table} exists check: {exists}")
if not exists:
# Create table from Pydantic model
logger.debug(f"Creating table {table} with model {model_class}")
self._create_table_from_model(cursor, table, model_class)
logger.info(f"Created table '{table}' with columns from Pydantic model")
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error ensuring table {table} exists: {e}")
if hasattr(self, 'connection') and self.connection:
self.connection.rollback()
return False
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
"""Create table with columns matching Pydantic model fields."""
fields = _get_model_fields(model_class)
logger.debug(f"Creating table {table} with fields: {fields}")
# Build column definitions with quoted identifiers to preserve exact case
columns = ['"id" VARCHAR(255) PRIMARY KEY']
for field_name, sql_type in fields.items():
if field_name != 'id': # Skip id, already defined
columns.append(f'"{field_name}" {sql_type}')
# Add metadata columns
columns.extend([
'"_createdAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP',
'"_modifiedAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP',
'"_createdBy" VARCHAR(255)',
'"_modifiedBy" VARCHAR(255)'
])
# Create table
sql = f'CREATE TABLE IF NOT EXISTS "{table}" ({", ".join(columns)})'
logger.debug(f"Executing SQL: {sql}")
cursor.execute(sql)
# Create indexes for foreign keys
for field_name in fields:
if field_name.endswith('Id') and field_name != 'id':
cursor.execute(f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")')
def _save_record(self, cursor, table: str, recordId: str, record: Dict[str, Any], model_class: type) -> None:
"""Save record to normalized table with explicit columns."""
# Get columns from Pydantic model instead of database schema
fields = _get_model_fields(model_class)
columns = ['id'] + [field for field in fields.keys() if field != 'id'] + ['_createdAt', '_createdBy', '_modifiedAt', '_modifiedBy']
logger.debug(f"Table {table} columns: {columns}")
logger.debug(f"Record data: {record}")
if not columns:
logger.error(f"No columns found for table {table}")
return
# Filter record data to only include columns that exist in the table
filtered_record = {k: v for k, v in record.items() if k in columns}
# Ensure id is set
filtered_record['id'] = recordId
# Prepare values in the correct order
values = []
for col in columns:
value = filtered_record.get(col)
# Convert timestamp fields to proper PostgreSQL format
if col in ['_createdAt', '_modifiedAt'] and value is not None:
if isinstance(value, (int, float)):
# Convert Unix timestamp to PostgreSQL timestamp
from datetime import datetime
value = datetime.fromtimestamp(value)
elif isinstance(value, str):
# If it's already a string, try to parse it
try:
from datetime import datetime
value = datetime.fromtimestamp(float(value))
except:
pass # Keep as string if parsing fails
# Convert enum values to their string representation
elif hasattr(value, 'value'):
value = value.value
# Handle JSONB fields - ensure proper JSON format for PostgreSQL
elif col in fields and fields[col] == 'JSONB' and value is not None:
import json
if isinstance(value, (dict, list)):
# Convert Python objects to JSON string for PostgreSQL JSONB
value = json.dumps(value)
elif isinstance(value, str):
# Validate that it's valid JSON, if not, try to parse and re-serialize
try:
# Test if it's already valid JSON
json.loads(value)
# If successful, keep as is
pass
except (json.JSONDecodeError, TypeError):
# If not valid JSON, convert to JSON string
value = json.dumps(value)
else:
# Convert other types to JSON
value = json.dumps(value)
values.append(value)
logger.debug(f"Values to insert: {values}")
# Build INSERT/UPDATE with quoted identifiers
col_names = ', '.join([f'"{col}"' for col in columns])
placeholders = ', '.join(['%s'] * len(columns))
updates = ', '.join([f'"{col}" = EXCLUDED."{col}"' for col in columns[1:] if col not in ['_createdAt', '_createdBy']])
sql = f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders}) ON CONFLICT ("id") DO UPDATE SET {updates}'
logger.debug(f"SQL: {sql}")
cursor.execute(sql, values)
def _loadRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
"""Loads a single record from the normalized table."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return None
with self.connection.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
row = cursor.fetchone()
if not row:
return None
# Convert row to dict and handle JSONB fields
record = dict(row)
fields = _get_model_fields(model_class)
# Parse JSONB fields back to Python objects
for field_name, field_type in fields.items():
if field_type == 'JSONB' and field_name in record and record[field_name] is not None:
import json
try:
if isinstance(record[field_name], str):
# Parse JSON string back to Python object
record[field_name] = json.loads(record[field_name])
elif isinstance(record[field_name], (dict, list)):
# Already a Python object, keep as is
pass
else:
# Try to parse as JSON
record[field_name] = json.loads(str(record[field_name]))
except (json.JSONDecodeError, TypeError, ValueError):
# If parsing fails, keep as string
logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}")
pass
return record
except Exception as e:
logger.error(f"Error loading record {recordId} from table {table}: {e}")
return None
def _saveRecord(self, model_class: type, recordId: str, record: Dict[str, Any]) -> bool:
"""Saves a single record to the table."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return False
recordId = str(recordId)
if "id" in record and str(record["id"]) != recordId:
raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}")
# Add metadata
currentTime = get_utc_timestamp()
if "_createdAt" not in record:
record["_createdAt"] = currentTime
record["_createdBy"] = self.userId
record["_modifiedAt"] = currentTime
record["_modifiedBy"] = self.userId
with self.connection.cursor() as cursor:
self._save_record(cursor, table, recordId, record, model_class)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error saving record {recordId} to table {table}: {e}")
self.connection.rollback()
return False
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
"""Loads all records from a normalized table."""
table = model_class.__name__
if table == self._systemTableName:
return self._loadSystemTable()
try:
if not self._ensureTableExists(model_class):
return []
with self.connection.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
records = [dict(row) for row in cursor.fetchall()]
# Handle JSONB fields for all records
fields = _get_model_fields(model_class)
for record in records:
for field_name, field_type in fields.items():
if field_type == 'JSONB' and field_name in record and record[field_name] is not None:
import json
try:
if isinstance(record[field_name], str):
# Parse JSON string back to Python object
record[field_name] = json.loads(record[field_name])
elif isinstance(record[field_name], (dict, list)):
# Already a Python object, keep as is
pass
else:
# Try to parse as JSON
record[field_name] = json.loads(str(record[field_name]))
except (json.JSONDecodeError, TypeError, ValueError):
# If parsing fails, keep as string
logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}")
pass
return records
except Exception as e:
logger.error(f"Error loading table {table}: {e}")
return []
def _applyRecordFilter(self, records: List[Dict[str, Any]], recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""Applies a record filter to the records"""
if not recordFilter:
return records
filteredRecords = []
for record in records:
match = True
for field, value in recordFilter.items():
# Check if the field exists
if field not in record:
match = False
break
# Convert both values to strings for comparison
recordValue = str(record[field])
filterValue = str(value)
# Direct string comparison
if recordValue != filterValue:
match = False
break
if match:
filteredRecords.append(record)
return filteredRecords
def _registerInitialId(self, table: str, initialId: str) -> bool:
"""Registers the initial ID for a table."""
try:
systemData = self._loadSystemTable()
if table not in systemData:
systemData[table] = initialId
success = self._saveSystemTable(systemData)
if success:
logger.info(f"Initial ID {initialId} for table {table} registered")
return success
else:
# Check if the existing initial ID still exists in the table
existingInitialId = systemData[table]
records = self.getRecordset(model_class, recordFilter={"id": existingInitialId})
if not records:
# The initial record no longer exists, update to the new one
systemData[table] = initialId
success = self._saveSystemTable(systemData)
if success:
logger.info(f"Initial ID updated from {existingInitialId} to {initialId} for table {table}")
return success
else:
logger.debug(f"Initial ID {existingInitialId} for table {table} already exists and is valid")
return True
except Exception as e:
logger.error(f"Error registering the initial ID for table {table}: {e}")
return False
def _removeInitialId(self, table: str) -> bool:
"""Removes the initial ID for a table from the system table."""
try:
systemData = self._loadSystemTable()
if table in systemData:
del systemData[table]
success = self._saveSystemTable(systemData)
if success:
logger.info(f"Initial ID for table {table} removed from system table")
return success
return True # If not present, this is not an error
except Exception as e:
logger.error(f"Error removing initial ID for table {table}: {e}")
return False
def updateContext(self, userId: str) -> None:
"""Updates the context of the database connector."""
if userId is None:
raise ValueError("userId must be provided")
self.userId = userId
logger.info(f"Updated database context: userId={self.userId}")
# No cache to clear - database handles data consistency
def clearTableCache(self, model_class: type) -> None:
"""No-op: Database handles data consistency automatically."""
# No caching with proper database - PostgreSQL handles consistency
pass
# Public API
def getTables(self) -> List[str]:
"""Returns a list of all available tables."""
tables = []
try:
with self.connection.cursor() as cursor:
cursor.execute("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name NOT LIKE '_%'
ORDER BY table_name
""")
rows = cursor.fetchall()
tables = [row['table_name'] for row in rows]
except Exception as e:
logger.error(f"Error reading the database: {e}")
return tables
def getFields(self, model_class: type) -> List[str]:
"""Returns a list of all fields in a table."""
data = self._loadTable(model_class)
if not data:
return []
fields = list(data[0].keys()) if data else []
return fields
def getSchema(self, model_class: type, language: str = None) -> Dict[str, Dict[str, Any]]:
"""Returns a schema object for a table with data types and labels."""
data = self._loadTable(model_class)
schema = {}
if not data:
return schema
firstRecord = data[0]
for field, value in firstRecord.items():
dataType = type(value).__name__
label = field
schema[field] = {
"type": dataType,
"label": label
}
return schema
def getRecordset(self, model_class: type, fieldFilter: List[str] = None, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""Returns a list of records from a table, filtered by criteria."""
table = model_class.__name__
# If we have specific record IDs in the filter, only load those records
if recordFilter and "id" in recordFilter:
recordId = recordFilter["id"]
record = self._loadRecord(model_class, recordId)
if record:
records = [record]
else:
return []
else:
# Load all records if no specific ID filter
records = self._loadTable(model_class)
# Apply recordFilter if available
if recordFilter:
records = self._applyRecordFilter(records, recordFilter)
# If fieldFilter is available, reduce the fields
if fieldFilter and isinstance(fieldFilter, list):
result = []
for record in records:
filteredRecord = {}
for field in fieldFilter:
if field in record:
filteredRecord[field] = record[field]
result.append(filteredRecord)
return result
return records
def recordCreate(self, model_class: type, record: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
"""Creates a new record in a table based on Pydantic model class."""
# If record is a Pydantic model, convert to dict
if isinstance(record, BaseModel):
record = to_dict(record)
elif isinstance(record, dict):
record = record.copy()
else:
raise ValueError("Record must be a Pydantic model or dictionary")
# Ensure record has an ID
if "id" not in record:
record["id"] = str(uuid.uuid4())
# Save record
self._saveRecord(model_class, record["id"], record)
# Check if this is the first record in the table and register as initial ID
table = model_class.__name__
existingInitialId = self.getInitialId(model_class)
if existingInitialId is None:
# This is the first record, register it as the initial ID
self._registerInitialId(table, record["id"])
logger.info(f"Registered initial ID {record['id']} for table {table}")
return record
def recordModify(self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
"""Modifies an existing record in a table based on Pydantic model class."""
# Load existing record
existingRecord = self._loadRecord(model_class, recordId)
if not existingRecord:
table = model_class.__name__
raise ValueError(f"Record {recordId} not found in table {table}")
# If record is a Pydantic model, convert to dict
if isinstance(record, BaseModel):
record = to_dict(record)
elif isinstance(record, dict):
record = record.copy()
else:
raise ValueError("Record must be a Pydantic model or dictionary")
# CRITICAL: Ensure we never modify the ID
if "id" in record and str(record["id"]) != recordId:
logger.error(f"Attempted to modify record ID from {recordId} to {record['id']}")
raise ValueError("Cannot modify record ID - it must match the provided recordId")
# Update existing record with new data
existingRecord.update(record)
# Save updated record
self._saveRecord(model_class, recordId, existingRecord)
return existingRecord
def recordDelete(self, model_class: type, recordId: str) -> bool:
"""Deletes a record from the table based on Pydantic model class."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return False
with self.connection.cursor() as cursor:
# Check if record exists
cursor.execute(f"SELECT id FROM {table} WHERE id = %s", (recordId,))
if not cursor.fetchone():
return False
# Check if it's an initial record
initialId = self.getInitialId(model_class)
if initialId is not None and initialId == recordId:
self._removeInitialId(table)
logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table")
# Delete the record
cursor.execute(f"DELETE FROM {table} WHERE id = %s", (recordId,))
# No cache to update - database handles consistency
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
self.connection.rollback()
return False
def getInitialId(self, model_class: type) -> Optional[str]:
"""Returns the initial ID for a table."""
table = model_class.__name__
systemData = self._loadSystemTable()
initialId = systemData.get(table)
logger.debug(f"Initial ID for table '{table}': {initialId}")
return initialId
def close(self):
"""Close the database connection."""
if hasattr(self, 'connection') and self.connection and not self.connection.closed:
self.connection.close()
logger.debug("Database connection closed")
def __del__(self):
"""Cleanup method to close connection."""
try:
self.close()
except Exception:
# Ignore errors during cleanup
pass

View file

@ -1,178 +0,0 @@
import threading
import queue
import time
import logging
from typing import Optional, Dict, Any
from .connectorDbJson import DatabaseConnector
logger = logging.getLogger(__name__)
class DatabaseConnectorPool:
"""
A connection pool for DatabaseConnector instances to manage resources efficiently
and ensure proper isolation between users.
"""
def __init__(self, max_connections: int = 100, max_idle_time: int = 300):
"""
Initialize the connection pool.
Args:
max_connections: Maximum number of connections in the pool
max_idle_time: Maximum idle time in seconds before connection is considered stale
"""
self.max_connections = max_connections
self.max_idle_time = max_idle_time
self._pool = queue.Queue(maxsize=max_connections)
self._created_connections = 0
self._lock = threading.Lock()
self._connection_times = {} # Track when connections were created
def _create_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None,
dbPassword: str = None, userId: str = None) -> DatabaseConnector:
"""Create a new DatabaseConnector instance."""
with self._lock:
if self._created_connections >= self.max_connections:
raise RuntimeError(f"Maximum connections ({self.max_connections}) exceeded")
self._created_connections += 1
logger.debug(f"Creating new database connector (total: {self._created_connections})")
connector = DatabaseConnector(
dbHost=dbHost,
dbDatabase=dbDatabase,
dbUser=dbUser,
dbPassword=dbPassword,
userId=userId
)
# Track creation time
connector_id = id(connector)
self._connection_times[connector_id] = time.time()
return connector
def get_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None,
dbPassword: str = None, userId: str = None) -> DatabaseConnector:
"""
Get a database connector from the pool or create a new one.
Args:
dbHost: Database host path
dbDatabase: Database name
dbUser: Database user (optional)
dbPassword: Database password (optional)
userId: User ID for context (optional)
Returns:
DatabaseConnector instance
"""
try:
# Try to get an existing connector from the pool
connector = self._pool.get_nowait()
# Check if connector is stale
connector_id = id(connector)
if connector_id in self._connection_times:
idle_time = time.time() - self._connection_times[connector_id]
if idle_time > self.max_idle_time:
logger.debug(f"Connector {connector_id} is stale (idle: {idle_time}s), creating new one")
# Remove stale connector from tracking
if connector_id in self._connection_times:
del self._connection_times[connector_id]
# Create new connector
return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)
# Update user context if provided
if userId is not None:
connector.updateContext(userId)
logger.debug(f"Reusing existing connector {connector_id}")
return connector
except queue.Empty:
# Pool is empty, create new connector
return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)
def return_connector(self, connector: DatabaseConnector) -> None:
"""
Return a connector to the pool for reuse.
Args:
connector: DatabaseConnector instance to return
"""
try:
# Update connection time
connector_id = id(connector)
self._connection_times[connector_id] = time.time()
# Try to return to pool
self._pool.put_nowait(connector)
logger.debug(f"Returned connector {connector_id} to pool")
except queue.Full:
# Pool is full, discard connector
logger.debug(f"Pool full, discarding connector {id(connector)}")
with self._lock:
self._created_connections -= 1
if id(connector) in self._connection_times:
del self._connection_times[id(connector)]
def cleanup_stale_connections(self) -> int:
"""
Clean up stale connections from the pool.
Returns:
Number of connections cleaned up
"""
cleaned = 0
current_time = time.time()
# Check all tracked connections
stale_connectors = []
for connector_id, creation_time in list(self._connection_times.items()):
if current_time - creation_time > self.max_idle_time:
stale_connectors.append(connector_id)
# Remove stale connections from tracking
for connector_id in stale_connectors:
if connector_id in self._connection_times:
del self._connection_times[connector_id]
cleaned += 1
logger.debug(f"Cleaned up {cleaned} stale connections")
return cleaned
def get_stats(self) -> Dict[str, Any]:
"""Get pool statistics."""
with self._lock:
return {
"max_connections": self.max_connections,
"created_connections": self._created_connections,
"available_connections": self._pool.qsize(),
"tracked_connections": len(self._connection_times)
}
# Global pool instance
_connector_pool = None
_pool_lock = threading.Lock()
def get_connector_pool() -> DatabaseConnectorPool:
"""Get the global connector pool instance."""
global _connector_pool
if _connector_pool is None:
with _pool_lock:
if _connector_pool is None:
_connector_pool = DatabaseConnectorPool()
return _connector_pool
def get_connector(dbHost: str, dbDatabase: str, dbUser: str = None,
dbPassword: str = None, userId: str = None) -> DatabaseConnector:
"""Get a database connector from the global pool."""
pool = get_connector_pool()
return pool.get_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)
def return_connector(connector: DatabaseConnector) -> None:
"""Return a database connector to the global pool."""
pool = get_connector_pool()
pool.return_connector(connector)

View file

@ -5,7 +5,7 @@ Access control for the Application.
import logging import logging
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from datetime import datetime from datetime import datetime
from modules.interfaces.interfaceAppModel import UserPrivilege, User from modules.interfaces.interfaceAppModel import UserPrivilege, User, UserInDB, AuthEvent
from modules.shared.timezoneUtils import get_utc_now from modules.shared.timezoneUtils import get_utc_now
# Configure logger # Configure logger
@ -29,28 +29,29 @@ class AppAccess:
self.db = db self.db = db
def uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
""" """
Unified user access management function that filters data based on user privileges Unified user access management function that filters data based on user privileges
and adds access control attributes. and adds access control attributes.
Args: Args:
table: Name of the table model_class: Pydantic model class for the table
recordset: Recordset to filter based on access rules recordset: Recordset to filter based on access rules
Returns: Returns:
Filtered recordset with access control attributes Filtered recordset with access control attributes
""" """
filtered_records = [] filtered_records = []
table_name = model_class.__name__
# Only SYSADMIN can see mandates # Only SYSADMIN can see mandates
if table == "mandates": if table_name == "Mandate":
if self.privilege == UserPrivilege.SYSADMIN: if self.privilege == UserPrivilege.SYSADMIN:
filtered_records = recordset filtered_records = recordset
else: else:
filtered_records = [] filtered_records = []
# Special handling for users table # Special handling for users table
elif table == "users": elif table_name == "UserInDB":
if self.privilege == UserPrivilege.SYSADMIN: if self.privilege == UserPrivilege.SYSADMIN:
# SysAdmin sees all users # SysAdmin sees all users
filtered_records = recordset filtered_records = recordset
@ -61,13 +62,13 @@ class AppAccess:
# Regular users only see themselves # Regular users only see themselves
filtered_records = [r for r in recordset if r.get("id") == self.userId] filtered_records = [r for r in recordset if r.get("id") == self.userId]
# Special handling for connections table # Special handling for connections table
elif table == "connections": elif table_name == "UserConnection":
if self.privilege == UserPrivilege.SYSADMIN: if self.privilege == UserPrivilege.SYSADMIN:
# SysAdmin sees all connections # SysAdmin sees all connections
filtered_records = recordset filtered_records = recordset
elif self.privilege == UserPrivilege.ADMIN: elif self.privilege == UserPrivilege.ADMIN:
# Admin sees connections for users in their mandate # Admin sees connections for users in their mandate
users: List[Dict[str, Any]] = self.db.getRecordset("users", recordFilter={"mandateId": self.mandateId}) users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId})
user_ids: List[str] = [str(u["id"]) for u in users] user_ids: List[str] = [str(u["id"]) for u in users]
filtered_records = [r for r in recordset if r.get("userId") in user_ids] filtered_records = [r for r in recordset if r.get("userId") in user_ids]
else: else:
@ -89,11 +90,11 @@ class AppAccess:
record_id = record.get("id") record_id = record.get("id")
# Set access control flags based on user permissions # Set access control flags based on user permissions
if table == "mandates": if table_name == "Mandate":
record["_hideView"] = False # SYSADMIN can view record["_hideView"] = False # SYSADMIN can view
record["_hideEdit"] = not self.canModify("mandates", record_id) record["_hideEdit"] = not self.canModify(Mandate, record_id)
record["_hideDelete"] = not self.canModify("mandates", record_id) record["_hideDelete"] = not self.canModify(Mandate, record_id)
elif table == "users": elif table_name == "UserInDB":
record["_hideView"] = False # Everyone can view users they have access to record["_hideView"] = False # Everyone can view users they have access to
# SysAdmin can edit/delete any user # SysAdmin can edit/delete any user
if self.privilege == UserPrivilege.SYSADMIN: if self.privilege == UserPrivilege.SYSADMIN:
@ -107,7 +108,7 @@ class AppAccess:
else: else:
record["_hideEdit"] = record.get("id") != self.userId record["_hideEdit"] = record.get("id") != self.userId
record["_hideDelete"] = True # Regular users cannot delete users record["_hideDelete"] = True # Regular users cannot delete users
elif table == "connections": elif table_name == "UserConnection":
# Everyone can view connections they have access to # Everyone can view connections they have access to
record["_hideView"] = False record["_hideView"] = False
# SysAdmin can edit/delete any connection # SysAdmin can edit/delete any connection
@ -116,7 +117,7 @@ class AppAccess:
record["_hideDelete"] = False record["_hideDelete"] = False
# Admin can edit/delete connections for users in their mandate # Admin can edit/delete connections for users in their mandate
elif self.privilege == UserPrivilege.ADMIN: elif self.privilege == UserPrivilege.ADMIN:
users: List[Dict[str, Any]] = self.db.getRecordset("users", recordFilter={"mandateId": self.mandateId}) users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId})
user_ids: List[str] = [str(u["id"]) for u in users] user_ids: List[str] = [str(u["id"]) for u in users]
record["_hideEdit"] = record.get("userId") not in user_ids record["_hideEdit"] = record.get("userId") not in user_ids
record["_hideDelete"] = record.get("userId") not in user_ids record["_hideDelete"] = record.get("userId") not in user_ids
@ -125,35 +126,37 @@ class AppAccess:
record["_hideEdit"] = record.get("userId") != self.userId record["_hideEdit"] = record.get("userId") != self.userId
record["_hideDelete"] = record.get("userId") != self.userId record["_hideDelete"] = record.get("userId") != self.userId
elif table == "auth_events": elif table_name == "AuthEvent":
# Only show auth events for the current user or if admin # Only show auth events for the current user or if admin
if self.privilege in [UserPrivilege.SYSADMIN, UserPrivilege.ADMIN]: if self.privilege in [UserPrivilege.SYSADMIN, UserPrivilege.ADMIN]:
record["_hideView"] = False record["_hideView"] = False
else: else:
record["_hideView"] = record.get("userId") != self.userId record["_hideView"] = record.get("userId") != self.userId
record["_hideEdit"] = True # Auth events can't be edited record["_hideEdit"] = True # Auth events can't be edited
record["_hideDelete"] = not self.canModify("auth_events", record_id) record["_hideDelete"] = not self.canModify(AuthEvent, record_id)
else: else:
# Default access control for other tables # Default access control for other tables
record["_hideView"] = False record["_hideView"] = False
record["_hideEdit"] = not self.canModify(table, record_id) record["_hideEdit"] = not self.canModify(model_class, record_id)
record["_hideDelete"] = not self.canModify(table, record_id) record["_hideDelete"] = not self.canModify(model_class, record_id)
return filtered_records return filtered_records
def canModify(self, table: str, recordId: Optional[str] = None) -> bool: def canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
""" """
Checks if the current user can modify (create/update/delete) records in a table. Checks if the current user can modify (create/update/delete) records in a table.
Args: Args:
table: Name of the table model_class: Pydantic model class for the table
recordId: Optional record ID for specific record check recordId: Optional record ID for specific record check
Returns: Returns:
Boolean indicating permission Boolean indicating permission
""" """
table_name = model_class.__name__
# For mandates, only SYSADMIN can modify # For mandates, only SYSADMIN can modify
if table == "mandates": if table_name == "Mandate":
return self.privilege == UserPrivilege.SYSADMIN return self.privilege == UserPrivilege.SYSADMIN
# System admins can modify anything else # System admins can modify anything else
@ -163,17 +166,17 @@ class AppAccess:
# Check specific record permissions # Check specific record permissions
if recordId is not None: if recordId is not None:
# Get the record to check ownership # Get the record to check ownership
records: List[Dict[str, Any]] = self.db.getRecordset(table, recordFilter={"id": str(recordId)}) records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": str(recordId)})
if not records: if not records:
return False return False
record = records[0] record = records[0]
# Special handling for connections # Special handling for connections
if table == "connections": if table_name == "UserConnection":
# Admin can modify connections for users in their mandate # Admin can modify connections for users in their mandate
if self.privilege == UserPrivilege.ADMIN: if self.privilege == UserPrivilege.ADMIN:
users: List[Dict[str, Any]] = self.db.getRecordset("users", recordFilter={"mandateId": self.mandateId}) users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId})
user_ids: List[str] = [str(u["id"]) for u in users] user_ids: List[str] = [str(u["id"]) for u in users]
return record.get("userId") in user_ids return record.get("userId") in user_ids
# Users can only modify their own connections # Users can only modify their own connections

View file

@ -354,3 +354,92 @@ class MsftToken(Token):
"""Microsoft OAuth token model""" """Microsoft OAuth token model"""
pass pass
class AuthEvent(BaseModel, ModelMixin):
"""Data model for authentication events"""
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID of the auth event",
frontend_type="text",
frontend_readonly=True,
frontend_required=False
)
userId: str = Field(
description="ID of the user this event belongs to",
frontend_type="text",
frontend_readonly=True,
frontend_required=True
)
eventType: str = Field(
description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')",
frontend_type="text",
frontend_readonly=True,
frontend_required=True
)
timestamp: float = Field(
default_factory=get_utc_timestamp,
description="Unix timestamp when the event occurred",
frontend_type="datetime",
frontend_readonly=True,
frontend_required=True
)
ipAddress: Optional[str] = Field(
default=None,
description="IP address from which the event originated",
frontend_type="text",
frontend_readonly=True,
frontend_required=False
)
userAgent: Optional[str] = Field(
default=None,
description="User agent string from the request",
frontend_type="text",
frontend_readonly=True,
frontend_required=False
)
success: bool = Field(
default=True,
description="Whether the authentication event was successful",
frontend_type="boolean",
frontend_readonly=True,
frontend_required=True
)
details: Optional[str] = Field(
default=None,
description="Additional details about the event",
frontend_type="text",
frontend_readonly=True,
frontend_required=False
)
# Register labels for AuthEvent
register_model_labels(
"AuthEvent",
{"en": "Authentication Event", "fr": "Événement d'authentification"},
{
"id": {"en": "ID", "fr": "ID"},
"userId": {"en": "User ID", "fr": "ID utilisateur"},
"eventType": {"en": "Event Type", "fr": "Type d'événement"},
"timestamp": {"en": "Timestamp", "fr": "Horodatage"},
"ipAddress": {"en": "IP Address", "fr": "Adresse IP"},
"userAgent": {"en": "User Agent", "fr": "Agent utilisateur"},
"success": {"en": "Success", "fr": "Succès"},
"details": {"en": "Details", "fr": "Détails"}
}
)
class SystemTable(BaseModel, ModelMixin):
"""Data model for system table entries"""
table_name: str = Field(
description="Name of the table",
frontend_type="text",
frontend_readonly=True,
frontend_required=True
)
initial_id: Optional[str] = Field(
default=None,
description="Initial ID for the table",
frontend_type="text",
frontend_readonly=True,
frontend_required=False
)

View file

@ -12,15 +12,14 @@ import json
from passlib.context import CryptContext from passlib.context import CryptContext
import uuid import uuid
from modules.connectors.connectorDbJson import DatabaseConnector from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.connectors.connectorPool import get_connector, return_connector
from modules.shared.configuration import APP_CONFIG from modules.shared.configuration import APP_CONFIG
from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp
from modules.interfaces.interfaceAppAccess import AppAccess from modules.interfaces.interfaceAppAccess import AppAccess
from modules.interfaces.interfaceAppModel import ( from modules.interfaces.interfaceAppModel import (
User, Mandate, UserInDB, UserConnection, User, Mandate, UserInDB, UserConnection,
AuthAuthority, UserPrivilege, AuthAuthority, UserPrivilege,
ConnectionStatus, Token ConnectionStatus, Token, AuthEvent
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -81,34 +80,36 @@ class AppObjects:
self.db.updateContext(self.userId) self.db.updateContext(self.userId)
def __del__(self): def __del__(self):
"""Cleanup method to return connector to pool.""" """Cleanup method to close database connection."""
if hasattr(self, 'db') and self.db is not None: if hasattr(self, 'db') and self.db is not None:
try: try:
return_connector(self.db) self.db.close()
except Exception as e: except Exception as e:
logger.error(f"Error returning connector to pool: {e}") logger.error(f"Error closing database connection: {e}")
def _initializeDatabase(self): def _initializeDatabase(self):
"""Initializes the database connection using connection pool.""" """Initializes the database connection directly."""
try: try:
# Get configuration values with defaults # Get configuration values with defaults
dbHost = APP_CONFIG.get("DB_APP_HOST", "_no_config_default_data") dbHost = APP_CONFIG.get("DB_APP_HOST", "_no_config_default_data")
dbDatabase = APP_CONFIG.get("DB_APP_DATABASE", "app") dbDatabase = APP_CONFIG.get("DB_APP_DATABASE", "app")
dbUser = APP_CONFIG.get("DB_APP_USER") dbUser = APP_CONFIG.get("DB_APP_USER")
dbPassword = APP_CONFIG.get("DB_APP_PASSWORD_SECRET") dbPassword = APP_CONFIG.get("DB_APP_PASSWORD_SECRET")
dbPort = int(APP_CONFIG.get("DB_APP_PORT", 5432))
# Ensure the database directory exists # Create database connector directly
os.makedirs(dbHost, exist_ok=True) self.db = DatabaseConnector(
# Get connector from pool with user context
self.db = get_connector(
dbHost=dbHost, dbHost=dbHost,
dbDatabase=dbDatabase, dbDatabase=dbDatabase,
dbUser=dbUser, dbUser=dbUser,
dbPassword=dbPassword, dbPassword=dbPassword,
dbPort=dbPort,
userId=self.userId userId=self.userId
) )
# Initialize database system
self.db.initDbSystem()
logger.info(f"Database initialized successfully for user {self.userId}") logger.info(f"Database initialized successfully for user {self.userId}")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize database: {str(e)}") logger.error(f"Failed to initialize database: {str(e)}")
@ -121,8 +122,8 @@ class AppObjects:
def _initRootMandate(self): def _initRootMandate(self):
"""Creates the Root mandate if it doesn't exist.""" """Creates the Root mandate if it doesn't exist."""
existingMandateId = self.getInitialId("mandates") existingMandateId = self.getInitialId(Mandate)
mandates = self.db.getRecordset("mandates") mandates = self.db.getRecordset(Mandate)
if existingMandateId is None or not mandates: if existingMandateId is None or not mandates:
logger.info("Creating Root mandate") logger.info("Creating Root mandate")
rootMandate = Mandate( rootMandate = Mandate(
@ -130,23 +131,20 @@ class AppObjects:
language="en", language="en",
enabled=True enabled=True
) )
createdMandate = self.db.recordCreate("mandates", rootMandate.to_dict()) createdMandate = self.db.recordCreate(Mandate, rootMandate)
logger.info(f"Root mandate created with ID {createdMandate['id']}") logger.info(f"Root mandate created with ID {createdMandate['id']}")
# Register the initial ID
self.db._registerInitialId("mandates", createdMandate['id'])
# Update mandate context # Update mandate context
self.mandateId = createdMandate['id'] self.mandateId = createdMandate['id']
def _initAdminUser(self): def _initAdminUser(self):
"""Creates the Admin user if it doesn't exist.""" """Creates the Admin user if it doesn't exist."""
existingUserId = self.getInitialId("users") existingUserId = self.getInitialId(UserInDB)
users = self.db.getRecordset("users") users = self.db.getRecordset(UserInDB)
if existingUserId is None or not users: if existingUserId is None or not users:
logger.info("Creating Admin user") logger.info("Creating Admin user")
adminUser = UserInDB( adminUser = UserInDB(
mandateId=self.getInitialId("mandates"), mandateId=self.getInitialId(Mandate),
username="admin", username="admin",
email="admin@example.com", email="admin@example.com",
fullName="Administrator", fullName="Administrator",
@ -157,30 +155,27 @@ class AppObjects:
hashedPassword=self._getPasswordHash("The 1st Poweron Admin"), # Use a secure password in production! hashedPassword=self._getPasswordHash("The 1st Poweron Admin"), # Use a secure password in production!
connections=[] connections=[]
) )
createdUser = self.db.recordCreate("users", adminUser.to_dict()) createdUser = self.db.recordCreate(UserInDB, adminUser)
logger.info(f"Admin user created with ID {createdUser['id']}") logger.info(f"Admin user created with ID {createdUser['id']}")
# Register the initial ID
self.db._registerInitialId("users", createdUser['id'])
# Update user context # Update user context
self.currentUser = createdUser self.currentUser = createdUser
self.userId = createdUser.get("id") self.userId = createdUser.get("id")
def _uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
""" """
Unified user access management function that filters data based on user privileges Unified user access management function that filters data based on user privileges
and adds access control attributes. and adds access control attributes.
Args: Args:
table: Name of the table model_class: Pydantic model class for the table
recordset: Recordset to filter based on access rules recordset: Recordset to filter based on access rules
Returns: Returns:
Filtered recordset with access control attributes Filtered recordset with access control attributes
""" """
# First apply access control # First apply access control
filteredRecords = self.access.uam(table, recordset) filteredRecords = self.access.uam(model_class, recordset)
# Then filter out database-specific fields # Then filter out database-specific fields
cleanedRecords = [] cleanedRecords = []
@ -191,26 +186,23 @@ class AppObjects:
return cleanedRecords return cleanedRecords
def _canModify(self, table: str, recordId: Optional[str] = None) -> bool: def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
""" """
Checks if the current user can modify (create/update/delete) records in a table. Checks if the current user can modify (create/update/delete) records in a table.
Args: Args:
table: Name of the table model_class: Pydantic model class for the table
recordId: Optional record ID for specific record check recordId: Optional record ID for specific record check
Returns: Returns:
Boolean indicating permission Boolean indicating permission
""" """
return self.access.canModify(table, recordId) return self.access.canModify(model_class, recordId)
def _clearTableCache(self, table: str) -> None:
"""Clears the cache for a specific table to ensure fresh data."""
self.db.clearTableCache(table)
def getInitialId(self, table: str) -> Optional[str]: def getInitialId(self, model_class: type) -> Optional[str]:
"""Returns the initial ID for a table.""" """Returns the initial ID for a table."""
return self.db.getInitialId(table) return self.db.getInitialId(model_class)
def _getPasswordHash(self, password: str) -> str: def _getPasswordHash(self, password: str) -> str:
"""Creates a hash for a password.""" """Creates a hash for a password."""
@ -225,8 +217,8 @@ class AppObjects:
def getUsersByMandate(self, mandateId: str) -> List[User]: def getUsersByMandate(self, mandateId: str) -> List[User]:
"""Returns users for a specific mandate if user has access.""" """Returns users for a specific mandate if user has access."""
# Get users for this mandate # Get users for this mandate
users = self.db.getRecordset("users", recordFilter={"mandateId": mandateId}) users = self.db.getRecordset(UserInDB, recordFilter={"mandateId": mandateId})
filteredUsers = self._uam("users", users) filteredUsers = self._uam(UserInDB, users)
# Convert to User models # Convert to User models
return [User.from_dict(user) for user in filteredUsers] return [User.from_dict(user) for user in filteredUsers]
@ -235,7 +227,7 @@ class AppObjects:
"""Returns a user by username.""" """Returns a user by username."""
try: try:
# Get users table # Get users table
users = self.db.getRecordset("users") users = self.db.getRecordset(UserInDB)
if not users: if not users:
return None return None
@ -255,7 +247,7 @@ class AppObjects:
"""Returns a user by ID if user has access.""" """Returns a user by ID if user has access."""
try: try:
# Get all users # Get all users
users = self.db.getRecordset("users") users = self.db.getRecordset(UserInDB)
if not users: if not users:
return None return None
@ -263,7 +255,7 @@ class AppObjects:
for user_dict in users: for user_dict in users:
if user_dict.get("id") == userId: if user_dict.get("id") == userId:
# Apply access control # Apply access control
filteredUsers = self._uam("users", [user_dict]) filteredUsers = self._uam(UserInDB, [user_dict])
if filteredUsers: if filteredUsers:
return User.from_dict(filteredUsers[0]) return User.from_dict(filteredUsers[0])
return None return None
@ -278,7 +270,7 @@ class AppObjects:
"""Returns all connections for a user.""" """Returns all connections for a user."""
try: try:
# Get connections for this user # Get connections for this user
connections = self.db.getRecordset("connections", recordFilter={"userId": userId}) connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
# Convert to UserConnection objects # Convert to UserConnection objects
result = [] result = []
@ -345,10 +337,8 @@ class AppObjects:
) )
# Save to connections table # Save to connections table
self.db.recordCreate("connections", connection.to_dict()) self.db.recordCreate(UserConnection, connection)
# Clear cache to ensure fresh data
self._clearTableCache("connections")
return connection return connection
@ -360,7 +350,7 @@ class AppObjects:
"""Remove a connection to an external service""" """Remove a connection to an external service"""
try: try:
# Get connection # Get connection
connections = self.db.getRecordset("connections", recordFilter={ connections = self.db.getRecordset(UserConnection, recordFilter={
"id": connectionId "id": connectionId
}) })
@ -368,10 +358,8 @@ class AppObjects:
raise ValueError(f"Connection {connectionId} not found") raise ValueError(f"Connection {connectionId} not found")
# Delete connection # Delete connection
self.db.recordDelete("connections", connectionId) self.db.recordDelete(UserConnection, connectionId)
# Clear cache to ensure fresh data
self._clearTableCache("connections")
except Exception as e: except Exception as e:
logger.error(f"Error removing user connection: {str(e)}") logger.error(f"Error removing user connection: {str(e)}")
@ -380,7 +368,6 @@ class AppObjects:
def authenticateLocalUser(self, username: str, password: str) -> Optional[User]: def authenticateLocalUser(self, username: str, password: str) -> Optional[User]:
"""Authenticates a user by username and password using local authentication.""" """Authenticates a user by username and password using local authentication."""
# Clear the users table from cache and reload it # Clear the users table from cache and reload it
self._clearTableCache("users")
# Get user by username # Get user by username
user = self.getUserByUsername(username) user = self.getUserByUsername(username)
@ -397,7 +384,7 @@ class AppObjects:
raise ValueError("User does not have local authentication enabled") raise ValueError("User does not have local authentication enabled")
# Get the full user record with password hash for verification # Get the full user record with password hash for verification
userRecord = self.db.getRecordset("users", recordFilter={"id": user.id})[0] userRecord = self.db.getRecordset(UserInDB, recordFilter={"id": user.id})[0]
if not userRecord.get("hashedPassword"): if not userRecord.get("hashedPassword"):
raise ValueError("User has no password set") raise ValueError("User has no password set")
@ -441,12 +428,10 @@ class AppObjects:
) )
# Create user record # Create user record
createdRecord = self.db.recordCreate("users", userData.to_dict()) createdRecord = self.db.recordCreate(UserInDB, userData)
if not createdRecord or not createdRecord.get("id"): if not createdRecord or not createdRecord.get("id"):
raise ValueError("Failed to create user record") raise ValueError("Failed to create user record")
# Clear cache to ensure fresh data
self._clearTableCache("users")
# Add external connection if provided # Add external connection if provided
if externalId and externalUsername: if externalId and externalUsername:
@ -459,12 +444,11 @@ class AppObjects:
) )
# Get created user using the returned ID # Get created user using the returned ID
createdUser = self.db.getRecordset("users", recordFilter={"id": createdRecord["id"]}) createdUser = self.db.getRecordset(UserInDB, recordFilter={"id": createdRecord["id"]})
if not createdUser or len(createdUser) == 0: if not createdUser or len(createdUser) == 0:
raise ValueError("Failed to retrieve created user") raise ValueError("Failed to retrieve created user")
# Clear cache to ensure fresh data (already done above) # Clear cache to ensure fresh data (already done above)
# No need for additional cache clearing since _clearTableCache("users") was called
return User.from_dict(createdUser[0]) return User.from_dict(createdUser[0])
@ -489,10 +473,8 @@ class AppObjects:
updatedUser = User.from_dict(updatedData) updatedUser = User.from_dict(updatedData)
# Update user record # Update user record
self.db.recordModify("users", userId, updatedUser.to_dict()) self.db.recordModify(UserInDB, userId, updatedUser)
# Clear cache to ensure fresh data
self._clearTableCache("users")
# Get updated user # Get updated user
updatedUser = self.getUser(userId) updatedUser = self.getUser(userId)
@ -519,20 +501,20 @@ class AppObjects:
# Delete user auth events # Delete user auth events
events = self.db.getRecordset("auth_events", recordFilter={"userId": userId}) events = self.db.getRecordset(AuthEvent, recordFilter={"userId": userId})
for event in events: for event in events:
self.db.recordDelete("auth_events", event["id"]) self.db.recordDelete(AuthEvent, event["id"])
# Delete user tokens # Delete user tokens
tokens = self.db.getRecordset("tokens", recordFilter={"userId": userId}) tokens = self.db.getRecordset(Token, recordFilter={"userId": userId})
for token in tokens: for token in tokens:
self.db.recordDelete("tokens", token["id"]) self.db.recordDelete(Token, token["id"])
# Delete user connections # Delete user connections
connections = self.db.getRecordset("connections", recordFilter={"userId": userId}) connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
for conn in connections: for conn in connections:
self.db.recordDelete("connections", conn["id"]) self.db.recordDelete(UserConnection, conn["id"])
logger.info(f"All referenced data for user {userId} has been deleted") logger.info(f"All referenced data for user {userId} has been deleted")
@ -548,19 +530,17 @@ class AppObjects:
if not user: if not user:
raise ValueError(f"User {userId} not found") raise ValueError(f"User {userId} not found")
if not self._canModify("users", userId): if not self._canModify(UserInDB, userId):
raise PermissionError(f"No permission to delete user {userId}") raise PermissionError(f"No permission to delete user {userId}")
# Delete all referenced data first # Delete all referenced data first
self._deleteUserReferencedData(userId) self._deleteUserReferencedData(userId)
# Delete user record # Delete user record
success = self.db.recordDelete("users", userId) success = self.db.recordDelete(UserInDB, userId)
if not success: if not success:
raise ValueError(f"Failed to delete user {userId}") raise ValueError(f"Failed to delete user {userId}")
# Clear cache to ensure fresh data
self._clearTableCache("users")
logger.info(f"User {userId} successfully deleted") logger.info(f"User {userId} successfully deleted")
return True return True
@ -573,17 +553,17 @@ class AppObjects:
def getAllMandates(self) -> List[Mandate]: def getAllMandates(self) -> List[Mandate]:
"""Returns all mandates based on user access level.""" """Returns all mandates based on user access level."""
allMandates = self.db.getRecordset("mandates") allMandates = self.db.getRecordset(Mandate)
filteredMandates = self._uam("mandates", allMandates) filteredMandates = self._uam(Mandate, allMandates)
return [Mandate.from_dict(mandate) for mandate in filteredMandates] return [Mandate.from_dict(mandate) for mandate in filteredMandates]
def getMandate(self, mandateId: str) -> Optional[Mandate]: def getMandate(self, mandateId: str) -> Optional[Mandate]:
"""Returns a mandate by ID if user has access.""" """Returns a mandate by ID if user has access."""
mandates = self.db.getRecordset("mandates", recordFilter={"id": mandateId}) mandates = self.db.getRecordset(Mandate, recordFilter={"id": mandateId})
if not mandates: if not mandates:
return None return None
filteredMandates = self._uam("mandates", mandates) filteredMandates = self._uam(Mandate, mandates)
if not filteredMandates: if not filteredMandates:
return None return None
@ -591,7 +571,7 @@ class AppObjects:
def createMandate(self, name: str, language: str = "en") -> Mandate: def createMandate(self, name: str, language: str = "en") -> Mandate:
"""Creates a new mandate if user has permission.""" """Creates a new mandate if user has permission."""
if not self._canModify("mandates"): if not self._canModify(Mandate):
raise PermissionError("No permission to create mandates") raise PermissionError("No permission to create mandates")
# Create mandate data using model # Create mandate data using model
@ -601,12 +581,10 @@ class AppObjects:
) )
# Create mandate record # Create mandate record
createdRecord = self.db.recordCreate("mandates", mandateData.to_dict()) createdRecord = self.db.recordCreate(Mandate, mandateData)
if not createdRecord or not createdRecord.get("id"): if not createdRecord or not createdRecord.get("id"):
raise ValueError("Failed to create mandate record") raise ValueError("Failed to create mandate record")
# Clear cache to ensure fresh data
self._clearTableCache("mandates")
return Mandate.from_dict(createdRecord) return Mandate.from_dict(createdRecord)
@ -614,7 +592,7 @@ class AppObjects:
"""Updates a mandate if user has access.""" """Updates a mandate if user has access."""
try: try:
# First check if user has permission to modify mandates # First check if user has permission to modify mandates
if not self._canModify("mandates", mandateId): if not self._canModify(Mandate, mandateId):
raise PermissionError(f"No permission to update mandate {mandateId}") raise PermissionError(f"No permission to update mandate {mandateId}")
# Get mandate with access control # Get mandate with access control
@ -628,10 +606,9 @@ class AppObjects:
updatedMandate = Mandate.from_dict(updatedData) updatedMandate = Mandate.from_dict(updatedData)
# Update mandate record # Update mandate record
self.db.recordModify("mandates", mandateId, updatedMandate.to_dict()) self.db.recordModify(Mandate, mandateId, updatedMandate)
# Clear cache to ensure fresh data # Clear cache to ensure fresh data
self._clearTableCache("mandates")
# Get updated mandate # Get updated mandate
updatedMandate = self.getMandate(mandateId) updatedMandate = self.getMandate(mandateId)
@ -652,7 +629,7 @@ class AppObjects:
if not mandate: if not mandate:
return False return False
if not self._canModify("mandates", mandateId): if not self._canModify(Mandate, mandateId):
raise PermissionError(f"No permission to delete mandate {mandateId}") raise PermissionError(f"No permission to delete mandate {mandateId}")
# Check if mandate has users # Check if mandate has users
@ -661,10 +638,9 @@ class AppObjects:
raise ValueError(f"Cannot delete mandate {mandateId} with existing users") raise ValueError(f"Cannot delete mandate {mandateId} with existing users")
# Delete mandate # Delete mandate
success = self.db.recordDelete("mandates", mandateId) success = self.db.recordDelete(Mandate, mandateId)
# Clear cache to ensure fresh data # Clear cache to ensure fresh data
self._clearTableCache("mandates")
return success return success
@ -675,11 +651,11 @@ class AppObjects:
def _getInitialUser(self) -> Optional[Dict[str, Any]]: def _getInitialUser(self) -> Optional[Dict[str, Any]]:
"""Get the initial user record directly from database without access control.""" """Get the initial user record directly from database without access control."""
try: try:
initialUserId = self.db.getInitialId("users") initialUserId = self.getInitialId(UserInDB)
if not initialUserId: if not initialUserId:
return None return None
users = self.db.getRecordset("users", recordFilter={"id": initialUserId}) users = self.db.getRecordset(UserInDB, recordFilter={"id": initialUserId})
return users[0] if users else None return users[0] if users else None
except Exception as e: except Exception as e:
logger.error(f"Error getting initial user: {str(e)}") logger.error(f"Error getting initial user: {str(e)}")
@ -742,7 +718,7 @@ class AppObjects:
# If replace_existing is True, delete old access tokens for this user and authority first # If replace_existing is True, delete old access tokens for this user and authority first
if replace_existing: if replace_existing:
try: try:
old_tokens = self.db.getRecordset("tokens", recordFilter={ old_tokens = self.db.getRecordset(Token, recordFilter={
"userId": self.currentUser.id, "userId": self.currentUser.id,
"authority": token.authority, "authority": token.authority,
"connectionId": None # Ensure we only delete access tokens "connectionId": None # Ensure we only delete access tokens
@ -750,7 +726,7 @@ class AppObjects:
deleted_count = 0 deleted_count = 0
for old_token in old_tokens: for old_token in old_tokens:
if old_token["id"] != token.id: # Don't delete the new token if it already exists if old_token["id"] != token.id: # Don't delete the new token if it already exists
self.db.recordDelete("tokens", old_token["id"]) self.db.recordDelete(Token, old_token["id"])
deleted_count += 1 deleted_count += 1
logger.debug(f"Deleted old access token {old_token['id']} for user {self.currentUser.id} and authority {token.authority}") logger.debug(f"Deleted old access token {old_token['id']} for user {self.currentUser.id} and authority {token.authority}")
@ -767,10 +743,8 @@ class AppObjects:
token_dict["userId"] = self.currentUser.id token_dict["userId"] = self.currentUser.id
# Save to database # Save to database
self.db.recordCreate("tokens", token_dict) self.db.recordCreate(Token, token_dict)
# Clear cache to ensure fresh data
self._clearTableCache("tokens")
except Exception as e: except Exception as e:
logger.error(f"Error saving access token: {str(e)}") logger.error(f"Error saving access token: {str(e)}")
@ -799,13 +773,13 @@ class AppObjects:
# If replace_existing is True, delete old tokens for this connectionId first # If replace_existing is True, delete old tokens for this connectionId first
if replace_existing: if replace_existing:
try: try:
old_tokens = self.db.getRecordset("tokens", recordFilter={ old_tokens = self.db.getRecordset(Token, recordFilter={
"connectionId": token.connectionId "connectionId": token.connectionId
}) })
deleted_count = 0 deleted_count = 0
for old_token in old_tokens: for old_token in old_tokens:
if old_token["id"] != token.id: # Don't delete the new token if it already exists if old_token["id"] != token.id: # Don't delete the new token if it already exists
self.db.recordDelete("tokens", old_token["id"]) self.db.recordDelete(Token, old_token["id"])
deleted_count += 1 deleted_count += 1
logger.debug(f"Deleted old token {old_token['id']} for connectionId {token.connectionId}") logger.debug(f"Deleted old token {old_token['id']} for connectionId {token.connectionId}")
@ -822,10 +796,8 @@ class AppObjects:
token_dict["userId"] = self.currentUser.id token_dict["userId"] = self.currentUser.id
# Save to database # Save to database
self.db.recordCreate("tokens", token_dict) self.db.recordCreate(Token, token_dict)
# Clear cache to ensure fresh data
self._clearTableCache("tokens")
except Exception as e: except Exception as e:
logger.error(f"Error saving connection token: {str(e)}") logger.error(f"Error saving connection token: {str(e)}")
@ -839,7 +811,7 @@ class AppObjects:
raise ValueError("No valid user context available for token retrieval") raise ValueError("No valid user context available for token retrieval")
# Get access tokens for this user and authority (must NOT have connectionId) # Get access tokens for this user and authority (must NOT have connectionId)
tokens = self.db.getRecordset("tokens", recordFilter={ tokens = self.db.getRecordset(Token, recordFilter={
"userId": self.currentUser.id, "userId": self.currentUser.id,
"authority": authority, "authority": authority,
"connectionId": None # Ensure we only get access tokens "connectionId": None # Ensure we only get access tokens
@ -888,7 +860,7 @@ class AppObjects:
# Get token for this specific connection # Get token for this specific connection
# Query for specific connection # Query for specific connection
tokens = self.db.getRecordset("tokens", recordFilter={ tokens = self.db.getRecordset(Token, recordFilter={
"connectionId": connectionId "connectionId": connectionId
}) })
@ -899,7 +871,7 @@ class AppObjects:
logger.debug(f"getConnectionToken: Token {i}: id={token.get('id')}, expiresAt={token.get('expiresAt')}, createdAt={token.get('createdAt')}") logger.debug(f"getConnectionToken: Token {i}: id={token.get('id')}, expiresAt={token.get('expiresAt')}, createdAt={token.get('createdAt')}")
else: else:
# Debug: Check if there are any tokens at all in the database # Debug: Check if there are any tokens at all in the database
all_tokens = self.db.getRecordset("tokens", recordFilter={}) all_tokens = self.db.getRecordset(Token, recordFilter={})
logger.debug(f"getConnectionToken: No tokens found for connectionId {connectionId}. Total tokens in database: {len(all_tokens)}") logger.debug(f"getConnectionToken: No tokens found for connectionId {connectionId}. Total tokens in database: {len(all_tokens)}")
if all_tokens: if all_tokens:
logger.debug(f"getConnectionToken: Sample tokens: {[{'id': t.get('id'), 'connectionId': t.get('connectionId'), 'authority': t.get('authority')} for t in all_tokens[:3]]}") logger.debug(f"getConnectionToken: Sample tokens: {[{'id': t.get('id'), 'connectionId': t.get('connectionId'), 'authority': t.get('authority')} for t in all_tokens[:3]]}")
@ -956,7 +928,7 @@ class AppObjects:
raise ValueError("No valid user context available for token deletion") raise ValueError("No valid user context available for token deletion")
# Get access tokens to delete (must NOT have connectionId) # Get access tokens to delete (must NOT have connectionId)
tokens = self.db.getRecordset("tokens", recordFilter={ tokens = self.db.getRecordset(Token, recordFilter={
"userId": self.currentUser.id, "userId": self.currentUser.id,
"authority": authority, "authority": authority,
"connectionId": None # Ensure we only delete access tokens "connectionId": None # Ensure we only delete access tokens
@ -964,10 +936,8 @@ class AppObjects:
# Delete each token # Delete each token
for token in tokens: for token in tokens:
self.db.recordDelete("tokens", token["id"]) self.db.recordDelete(Token, token["id"])
# Clear cache to ensure fresh data
self._clearTableCache("tokens")
except Exception as e: except Exception as e:
logger.error(f"Error deleting access token: {str(e)}") logger.error(f"Error deleting access token: {str(e)}")
@ -981,16 +951,14 @@ class AppObjects:
raise ValueError("connectionId is required for deleteConnectionTokenByConnectionId") raise ValueError("connectionId is required for deleteConnectionTokenByConnectionId")
# Get connection tokens to delete # Get connection tokens to delete
tokens = self.db.getRecordset("tokens", recordFilter={ tokens = self.db.getRecordset(Token, recordFilter={
"connectionId": connectionId "connectionId": connectionId
}) })
# Delete each token # Delete each token
for token in tokens: for token in tokens:
self.db.recordDelete("tokens", token["id"]) self.db.recordDelete(Token, token["id"])
# Clear cache to ensure fresh data
self._clearTableCache("tokens")
except Exception as e: except Exception as e:
logger.error(f"Error deleting connection token for connectionId {connectionId}: {str(e)}") logger.error(f"Error deleting connection token for connectionId {connectionId}: {str(e)}")
@ -1005,17 +973,16 @@ class AppObjects:
cleaned_count = 0 cleaned_count = 0
# Get all tokens # Get all tokens
all_tokens = self.db.getRecordset("tokens", recordFilter={}) all_tokens = self.db.getRecordset(Token, recordFilter={})
for token_data in all_tokens: for token_data in all_tokens:
if token_data.get("expiresAt") and token_data.get("expiresAt") < current_time: if token_data.get("expiresAt") and token_data.get("expiresAt") < current_time:
# Token is expired, delete it # Token is expired, delete it
self.db.recordDelete("tokens", token_data["id"]) self.db.recordDelete(Token, token_data["id"])
cleaned_count += 1 cleaned_count += 1
# Clear cache to ensure fresh data # Clear cache to ensure fresh data
if cleaned_count > 0: if cleaned_count > 0:
self._clearTableCache("tokens")
logger.info(f"Cleaned up {cleaned_count} expired tokens") logger.info(f"Cleaned up {cleaned_count} expired tokens")
return cleaned_count return cleaned_count
@ -1072,16 +1039,23 @@ def getRootUser() -> User:
tempInterface = AppObjects() tempInterface = AppObjects()
# Get the initial user directly # Get the initial user directly
initialUserId = tempInterface.db.getInitialId("users") initialUserId = tempInterface.getInitialId(UserInDB)
if not initialUserId: if not initialUserId:
raise ValueError("No initial user ID found in database") raise ValueError("No initial user ID found in database")
users = tempInterface.db.getRecordset("users", recordFilter={"id": initialUserId}) users = tempInterface.db.getRecordset(UserInDB, recordFilter={"id": initialUserId})
if not users: if not users:
raise ValueError("Initial user not found in database") raise ValueError("Initial user not found in database")
logger.debug(f"Retrieved user data: {users[0]}")
# Convert to User model and return the model instance # Convert to User model and return the model instance
return User.from_dict(users[0]) user_data = users[0]
logger.debug(f"User data keys: {list(user_data.keys())}")
logger.debug(f"User id: {user_data.get('id')}")
logger.debug(f"User mandateId: {user_data.get('mandateId')}")
return User.parse_obj(user_data)
except Exception as e: except Exception as e:
logger.error(f"Error getting root user: {str(e)}") logger.error(f"Error getting root user: {str(e)}")

View file

@ -5,6 +5,7 @@ Handles user access management and permission checks.
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from modules.interfaces.interfaceAppModel import User, UserPrivilege from modules.interfaces.interfaceAppModel import User, UserPrivilege
from modules.interfaces.interfaceChatModel import ChatWorkflow, ChatMessage, ChatLog, ChatStat, ChatDocument
class ChatAccess: class ChatAccess:
""" """
@ -23,19 +24,20 @@ class ChatAccess:
self.db = db self.db = db
def uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
""" """
Unified user access management function that filters data based on user privileges Unified user access management function that filters data based on user privileges
and adds access control attributes. and adds access control attributes.
Args: Args:
table: Name of the table model_class: Pydantic model class for the table
recordset: Recordset to filter based on access rules recordset: Recordset to filter based on access rules
Returns: Returns:
Filtered recordset with access control attributes Filtered recordset with access control attributes
""" """
userPrivilege = self.currentUser.privilege userPrivilege = self.currentUser.privilege
table_name = model_class.__name__
filtered_records = [] filtered_records = []
# Apply filtering based on privilege # Apply filtering based on privilege
@ -54,32 +56,32 @@ class ChatAccess:
record_id = record.get("id") record_id = record.get("id")
# Set access control flags based on user permissions # Set access control flags based on user permissions
if table == "workflows": if table_name == "ChatWorkflow":
record["_hideView"] = False # Everyone can view record["_hideView"] = False # Everyone can view
record["_hideEdit"] = not self.canModify("workflows", record_id) record["_hideEdit"] = not self.canModify(ChatWorkflow, record_id)
record["_hideDelete"] = not self.canModify("workflows", record_id) record["_hideDelete"] = not self.canModify(ChatWorkflow, record_id)
elif table == "workflowMessages": elif table_name == "ChatMessage":
record["_hideView"] = False # Everyone can view record["_hideView"] = False # Everyone can view
record["_hideEdit"] = not self.canModify("workflows", record.get("workflowId")) record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
record["_hideDelete"] = not self.canModify("workflows", record.get("workflowId")) record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
elif table == "workflowLogs": elif table_name == "ChatLog":
record["_hideView"] = False # Everyone can view record["_hideView"] = False # Everyone can view
record["_hideEdit"] = not self.canModify("workflows", record.get("workflowId")) record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
record["_hideDelete"] = not self.canModify("workflows", record.get("workflowId")) record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
else: else:
# Default access control for other tables # Default access control for other tables
record["_hideView"] = False record["_hideView"] = False
record["_hideEdit"] = not self.canModify(table, record_id) record["_hideEdit"] = not self.canModify(model_class, record_id)
record["_hideDelete"] = not self.canModify(table, record_id) record["_hideDelete"] = not self.canModify(model_class, record_id)
return filtered_records return filtered_records
def canModify(self, table: str, recordId: Optional[str] = None) -> bool: def canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
""" """
Checks if the current user can modify (create/update/delete) records in a table. Checks if the current user can modify (create/update/delete) records in a table.
Args: Args:
table: Name of the table model_class: Pydantic model class for the table
recordId: Optional record ID for specific record check recordId: Optional record ID for specific record check
Returns: Returns:
@ -94,7 +96,7 @@ class ChatAccess:
# For regular users and admins, check specific cases # For regular users and admins, check specific cases
if recordId is not None: if recordId is not None:
# Get the record to check ownership # Get the record to check ownership
records: List[Dict[str, Any]] = self.db.getRecordset(table, recordFilter={"id": recordId}) records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": recordId})
if not records: if not records:
return False return False

View file

@ -174,6 +174,7 @@ register_model_labels(
class ChatDocument(BaseModel, ModelMixin): class ChatDocument(BaseModel, ModelMixin):
"""Data model for a chat document""" """Data model for a chat document"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
messageId: str = Field(description="Foreign key to message")
fileId: str = Field(description="Foreign key to file") fileId: str = Field(description="Foreign key to file")
# Direct file attributes (copied from file object) # Direct file attributes (copied from file object)
@ -197,6 +198,7 @@ register_model_labels(
{"en": "Chat Document", "fr": "Document de chat"}, {"en": "Chat Document", "fr": "Document de chat"},
{ {
"id": {"en": "ID", "fr": "ID"}, "id": {"en": "ID", "fr": "ID"},
"messageId": {"en": "Message ID", "fr": "ID du message"},
"fileId": {"en": "File ID", "fr": "ID du fichier"}, "fileId": {"en": "File ID", "fr": "ID du fichier"},
"roundNumber": {"en": "Round Number", "fr": "Numéro de tour"}, "roundNumber": {"en": "Round Number", "fr": "Numéro de tour"},
"taskNumber": {"en": "Task Number", "fr": "Numéro de tâche"}, "taskNumber": {"en": "Task Number", "fr": "Numéro de tâche"},
@ -400,6 +402,8 @@ register_model_labels(
class ChatStat(BaseModel, ModelMixin): class ChatStat(BaseModel, ModelMixin):
"""Data model for chat statistics - ONLY statistics, not workflow progress""" """Data model for chat statistics - ONLY statistics, not workflow progress"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)")
messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)")
processingTime: Optional[float] = Field(None, description="Processing time in seconds") processingTime: Optional[float] = Field(None, description="Processing time in seconds")
tokenCount: Optional[int] = Field(None, description="Number of tokens processed") tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
bytesSent: Optional[int] = Field(None, description="Number of bytes sent") bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
@ -413,6 +417,8 @@ register_model_labels(
{"en": "Chat Statistics", "fr": "Statistiques de chat"}, {"en": "Chat Statistics", "fr": "Statistiques de chat"},
{ {
"id": {"en": "ID", "fr": "ID"}, "id": {"en": "ID", "fr": "ID"},
"workflowId": {"en": "Workflow ID", "fr": "ID du workflow"},
"messageId": {"en": "Message ID", "fr": "ID du message"},
"processingTime": {"en": "Processing Time", "fr": "Temps de traitement"}, "processingTime": {"en": "Processing Time", "fr": "Temps de traitement"},
"tokenCount": {"en": "Token Count", "fr": "Nombre de tokens"}, "tokenCount": {"en": "Token Count", "fr": "Nombre de tokens"},
"bytesSent": {"en": "Bytes Sent", "fr": "Octets envoyés"}, "bytesSent": {"en": "Bytes Sent", "fr": "Octets envoyés"},
@ -650,8 +656,8 @@ register_model_labels(
class TaskStep(BaseModel, ModelMixin): class TaskStep(BaseModel, ModelMixin):
id: str id: str
objective: str objective: str
dependencies: Optional[list[str]] = [] dependencies: Optional[list[str]] = Field(default_factory=list)
success_criteria: Optional[list[str]] = [] success_criteria: Optional[list[str]] = Field(default_factory=list)
estimated_complexity: Optional[str] = None estimated_complexity: Optional[str] = None
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language") userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
@ -733,23 +739,23 @@ class TaskContext(BaseModel, ModelMixin):
# Available resources # Available resources
available_documents: Optional[str] = "No documents available" available_documents: Optional[str] = "No documents available"
available_connections: Optional[list[str]] = [] available_connections: Optional[list[str]] = Field(default_factory=list)
# Previous execution state # Previous execution state
previous_results: Optional[list[str]] = [] previous_results: Optional[list[str]] = Field(default_factory=list)
previous_handover: Optional[TaskHandover] = None previous_handover: Optional[TaskHandover] = None
# Current execution state # Current execution state
improvements: Optional[list[str]] = [] improvements: Optional[list[str]] = Field(default_factory=list)
retry_count: Optional[int] = 0 retry_count: Optional[int] = 0
previous_action_results: Optional[list] = [] previous_action_results: Optional[list] = Field(default_factory=list)
previous_review_result: Optional[dict] = None previous_review_result: Optional[dict] = None
is_regeneration: Optional[bool] = False is_regeneration: Optional[bool] = False
# Failure analysis # Failure analysis
failure_patterns: Optional[list[str]] = [] failure_patterns: Optional[list[str]] = Field(default_factory=list)
failed_actions: Optional[list] = [] failed_actions: Optional[list] = Field(default_factory=list)
successful_actions: Optional[list] = [] successful_actions: Optional[list] = Field(default_factory=list)
# Criteria progress tracking for retries # Criteria progress tracking for retries
criteria_progress: Optional[dict] = None criteria_progress: Optional[dict] = None
@ -771,20 +777,20 @@ class TaskContext(BaseModel, ModelMixin):
class ReviewContext(BaseModel, ModelMixin): class ReviewContext(BaseModel, ModelMixin):
task_step: TaskStep task_step: TaskStep
task_actions: Optional[list] = [] task_actions: Optional[list] = Field(default_factory=list)
action_results: Optional[list] = [] action_results: Optional[list] = Field(default_factory=list)
step_result: Optional[dict] = {} step_result: Optional[dict] = Field(default_factory=dict)
workflow_id: Optional[str] = None workflow_id: Optional[str] = None
previous_results: Optional[list[str]] = [] previous_results: Optional[list[str]] = Field(default_factory=list)
class ReviewResult(BaseModel, ModelMixin): class ReviewResult(BaseModel, ModelMixin):
status: str status: str
reason: Optional[str] = None reason: Optional[str] = None
improvements: Optional[list[str]] = [] improvements: Optional[list[str]] = Field(default_factory=list)
quality_score: Optional[int] = 5 quality_score: Optional[int] = 5
missing_outputs: Optional[list[str]] = [] missing_outputs: Optional[list[str]] = Field(default_factory=list)
met_criteria: Optional[list[str]] = [] met_criteria: Optional[list[str]] = Field(default_factory=list)
unmet_criteria: Optional[list[str]] = [] unmet_criteria: Optional[list[str]] = Field(default_factory=list)
confidence: Optional[float] = 0.5 confidence: Optional[float] = 0.5
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language") userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")

File diff suppressed because it is too large Load diff

View file

@ -5,7 +5,9 @@ Handles user access management and permission checks.
import logging import logging
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from modules.interfaces.interfaceAppModel import User from modules.interfaces.interfaceAppModel import User, UserInDB
from modules.interfaces.interfaceComponentModel import Prompt, FileItem, FileData
from modules.interfaces.interfaceChatModel import ChatWorkflow, ChatMessage, ChatLog
# Configure logger # Configure logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,19 +49,20 @@ class ComponentAccess:
return True return True
def uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
""" """
Unified user access management function that filters data based on user privileges Unified user access management function that filters data based on user privileges
and adds access control attributes. and adds access control attributes.
Args: Args:
table: Name of the table model_class: Pydantic model class for the table
recordset: Recordset to filter based on access rules recordset: Recordset to filter based on access rules
Returns: Returns:
Filtered recordset with access control attributes Filtered recordset with access control attributes
""" """
userPrivilege = self.privilege userPrivilege = self.privilege
table_name = model_class.__name__
filtered_records = [] filtered_records = []
@ -73,9 +76,9 @@ class ComponentAccess:
filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId] filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId]
else: # Regular users else: # Regular users
# For prompts, users can see all prompts from their mandate # For prompts, users can see all prompts from their mandate
if table == "prompts": if table_name == "Prompt":
filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId] filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId]
elif table == "users": elif table_name == "UserInDB":
# For users table, users can only see their own record # For users table, users can only see their own record
filtered_records = [r for r in recordset if r.get("id") == self.userId] filtered_records = [r for r in recordset if r.get("id") == self.userId]
else: else:
@ -90,32 +93,32 @@ class ComponentAccess:
record_id = record.get("id") record_id = record.get("id")
# Set access control flags based on user permissions # Set access control flags based on user permissions
if table == "prompts": if table_name == "Prompt":
record["_hideView"] = False # Everyone can view record["_hideView"] = False # Everyone can view
record["_hideEdit"] = not self.canModify("prompts", record_id) record["_hideEdit"] = not self.canModify(Prompt, record_id)
record["_hideDelete"] = not self.canModify("prompts", record_id) record["_hideDelete"] = not self.canModify(Prompt, record_id)
# Add attribute-level permissions for mandateId # Add attribute-level permissions for mandateId
if "mandateId" in record: if "mandateId" in record:
record["_hideEdit_mandateId"] = not self.canModifyAttribute("prompts", "mandateId") record["_hideEdit_mandateId"] = not self.canModifyAttribute(Prompt, "mandateId")
elif table == "files": elif table_name == "FileItem":
record["_hideView"] = False # Everyone can view record["_hideView"] = False # Everyone can view
record["_hideEdit"] = not self.canModify("files", record_id) record["_hideEdit"] = not self.canModify(FileItem, record_id)
record["_hideDelete"] = not self.canModify("files", record_id) record["_hideDelete"] = not self.canModify(FileItem, record_id)
record["_hideDownload"] = not self.canModify("files", record_id) record["_hideDownload"] = not self.canModify(FileItem, record_id)
elif table == "workflows": elif table_name == "ChatWorkflow":
record["_hideView"] = False # Everyone can view record["_hideView"] = False # Everyone can view
record["_hideEdit"] = not self.canModify("workflows", record_id) record["_hideEdit"] = not self.canModify(ChatWorkflow, record_id)
record["_hideDelete"] = not self.canModify("workflows", record_id) record["_hideDelete"] = not self.canModify(ChatWorkflow, record_id)
elif table == "workflowMessages": elif table_name == "ChatMessage":
record["_hideView"] = False # Everyone can view record["_hideView"] = False # Everyone can view
record["_hideEdit"] = not self.canModify("workflows", record.get("workflowId")) record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
record["_hideDelete"] = not self.canModify("workflows", record.get("workflowId")) record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
elif table == "workflowLogs": elif table_name == "ChatLog":
record["_hideView"] = False # Everyone can view record["_hideView"] = False # Everyone can view
record["_hideEdit"] = not self.canModify("workflows", record.get("workflowId")) record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
record["_hideDelete"] = not self.canModify("workflows", record.get("workflowId")) record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
elif table == "users": elif table_name == "UserInDB":
# For users table, users can only modify their own connections # For users table, users can only modify their own connections
record["_hideView"] = False record["_hideView"] = False
record["_hideEdit"] = record_id != self.userId record["_hideEdit"] = record_id != self.userId
@ -128,17 +131,17 @@ class ComponentAccess:
else: else:
# Default access control for other tables # Default access control for other tables
record["_hideView"] = False record["_hideView"] = False
record["_hideEdit"] = not self.canModify(table, record_id) record["_hideEdit"] = not self.canModify(model_class, record_id)
record["_hideDelete"] = not self.canModify(table, record_id) record["_hideDelete"] = not self.canModify(model_class, record_id)
return filtered_records return filtered_records
def canModify(self, table: str, recordId: Optional[int] = None) -> bool: def canModify(self, model_class: type, recordId: Optional[int] = None) -> bool:
""" """
Checks if the current user can modify (create/update/delete) records in a table. Checks if the current user can modify (create/update/delete) records in a table.
Args: Args:
table: Name of the table model_class: Pydantic model class for the table
recordId: Optional record ID for specific record check recordId: Optional record ID for specific record check
Returns: Returns:
@ -153,14 +156,14 @@ class ComponentAccess:
# For regular users and admins, check specific cases # For regular users and admins, check specific cases
if recordId is not None: if recordId is not None:
# Get the record to check ownership # Get the record to check ownership
records: List[Dict[str, Any]] = self.db.getRecordset(table, recordFilter={"id": recordId}) records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": recordId})
if not records: if not records:
return False return False
record = records[0] record = records[0]
# Special case for users table - users can modify their own connections # Special case for users table - users can modify their own connections
if table == "users": if model_class.__name__ == "UserInDB":
if record.get("id") == self.userId: if record.get("id") == self.userId:
return True return True
return False return False

View file

@ -14,11 +14,10 @@ from modules.interfaces.interfaceComponentAccess import ComponentAccess
from modules.interfaces.interfaceComponentModel import ( from modules.interfaces.interfaceComponentModel import (
FilePreview, Prompt, FileItem, FileData FilePreview, Prompt, FileItem, FileData
) )
from modules.interfaces.interfaceAppModel import User from modules.interfaces.interfaceAppModel import User, Mandate
# DYNAMIC PART: Connectors to the Interface # DYNAMIC PART: Connectors to the Interface
from modules.connectors.connectorDbJson import DatabaseConnector from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.connectors.connectorPool import get_connector, return_connector
# Basic Configurations # Basic Configurations
from modules.shared.configuration import APP_CONFIG from modules.shared.configuration import APP_CONFIG
@ -90,35 +89,38 @@ class ComponentObjects:
self.db.updateContext(self.userId) self.db.updateContext(self.userId)
def __del__(self): def __del__(self):
"""Cleanup method to return connector to pool.""" """Cleanup method to close database connection."""
if hasattr(self, 'db') and self.db is not None: if hasattr(self, 'db') and self.db is not None:
try: try:
return_connector(self.db) self.db.close()
except Exception as e: except Exception as e:
logger.error(f"Error returning connector to pool: {e}") logger.error(f"Error closing database connection: {e}")
logger.debug(f"User context set: userId={self.userId}") logger.debug(f"User context set: userId={self.userId}")
def _initializeDatabase(self): def _initializeDatabase(self):
"""Initializes the database connection.""" """Initializes the database connection directly."""
try: try:
# Get configuration values with defaults # Get configuration values with defaults
dbHost = APP_CONFIG.get("DB_MANAGEMENT_HOST", "_no_config_default_data") dbHost = APP_CONFIG.get("DB_MANAGEMENT_HOST", "_no_config_default_data")
dbDatabase = APP_CONFIG.get("DB_MANAGEMENT_DATABASE", "management") dbDatabase = APP_CONFIG.get("DB_MANAGEMENT_DATABASE", "management")
dbUser = APP_CONFIG.get("DB_MANAGEMENT_USER") dbUser = APP_CONFIG.get("DB_MANAGEMENT_USER")
dbPassword = APP_CONFIG.get("DB_MANAGEMENT_PASSWORD_SECRET") dbPassword = APP_CONFIG.get("DB_MANAGEMENT_PASSWORD_SECRET")
dbPort = int(APP_CONFIG.get("DB_MANAGEMENT_PORT"))
# Ensure the database directory exists # Create database connector directly
os.makedirs(dbHost, exist_ok=True) self.db = DatabaseConnector(
self.db = get_connector(
dbHost=dbHost, dbHost=dbHost,
dbDatabase=dbDatabase, dbDatabase=dbDatabase,
dbUser=dbUser, dbUser=dbUser,
dbPassword=dbPassword, dbPassword=dbPassword,
dbPort=dbPort,
userId=self.userId if hasattr(self, 'userId') else None userId=self.userId if hasattr(self, 'userId') else None
) )
# Initialize database system
self.db.initDbSystem()
logger.info("Database initialized successfully") logger.info("Database initialized successfully")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize database: {str(e)}") logger.error(f"Failed to initialize database: {str(e)}")
@ -142,7 +144,7 @@ class ComponentObjects:
"""Initializes standard prompts if they don't exist yet.""" """Initializes standard prompts if they don't exist yet."""
try: try:
# Check if any prompts exist # Check if any prompts exist
existingPrompts = self.db.getRecordset("prompts") existingPrompts = self.db.getRecordset(Prompt)
if existingPrompts: if existingPrompts:
logger.info("Prompts already exist, skipping initialization") logger.info("Prompts already exist, skipping initialization")
return return
@ -152,7 +154,7 @@ class ComponentObjects:
rootInterface = getRootInterface() rootInterface = getRootInterface()
# Get initial mandate ID through the root interface # Get initial mandate ID through the root interface
mandateId = rootInterface.getInitialId("mandates") mandateId = rootInterface.getInitialId(Mandate)
if not mandateId: if not mandateId:
logger.error("No initial mandate ID found") logger.error("No initial mandate ID found")
return return
@ -205,7 +207,7 @@ class ComponentObjects:
# Create prompts # Create prompts
for prompt in standardPrompts: for prompt in standardPrompts:
self.db.recordCreate("prompts", prompt.to_dict()) self.db.recordCreate(Prompt, prompt)
logger.info(f"Created standard prompt: {prompt.name}") logger.info(f"Created standard prompt: {prompt.name}")
# Restore original user context if it existed # Restore original user context if it existed
@ -228,10 +230,10 @@ class ComponentObjects:
self.access = None self.access = None
self.db.updateContext("") # Reset database context self.db.updateContext("") # Reset database context
def _uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Delegate to access control module.""" """Delegate to access control module."""
# First apply access control # First apply access control
filteredRecords = self.access.uam(table, recordset) filteredRecords = self.access.uam(model_class, recordset)
# Then filter out database-specific fields # Then filter out database-specific fields
cleanedRecords = [] cleanedRecords = []
@ -242,19 +244,16 @@ class ComponentObjects:
return cleanedRecords return cleanedRecords
def _canModify(self, table: str, recordId: Optional[str] = None) -> bool: def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
"""Delegate to access control module.""" """Delegate to access control module."""
return self.access.canModify(table, recordId) return self.access.canModify(model_class, recordId)
def _clearTableCache(self, table: str) -> None:
"""Clears the cache for a specific table to ensure fresh data."""
self.db.clearTableCache(table)
# Utilities # Utilities
def getInitialId(self, table: str) -> Optional[str]: def getInitialId(self, model_class: type) -> Optional[str]:
"""Returns the initial ID for a table.""" """Returns the initial ID for a table."""
return self.db.getInitialId(table) return self.db.getInitialId(model_class)
@ -263,8 +262,8 @@ class ComponentObjects:
def getAllPrompts(self) -> List[Prompt]: def getAllPrompts(self) -> List[Prompt]:
"""Returns prompts based on user access level.""" """Returns prompts based on user access level."""
try: try:
allPrompts = self.db.getRecordset("prompts") allPrompts = self.db.getRecordset(Prompt)
filteredPrompts = self._uam("prompts", allPrompts) filteredPrompts = self._uam(Prompt, allPrompts)
# Convert to Prompt objects # Convert to Prompt objects
return [Prompt.from_dict(prompt) for prompt in filteredPrompts] return [Prompt.from_dict(prompt) for prompt in filteredPrompts]
@ -275,25 +274,23 @@ class ComponentObjects:
def getPrompt(self, promptId: str) -> Optional[Prompt]: def getPrompt(self, promptId: str) -> Optional[Prompt]:
"""Returns a prompt by ID if user has access.""" """Returns a prompt by ID if user has access."""
prompts = self.db.getRecordset("prompts", recordFilter={"id": promptId}) prompts = self.db.getRecordset(Prompt, recordFilter={"id": promptId})
if not prompts: if not prompts:
return None return None
filteredPrompts = self._uam("prompts", prompts) filteredPrompts = self._uam(Prompt, prompts)
return Prompt.from_dict(filteredPrompts[0]) if filteredPrompts else None return Prompt.from_dict(filteredPrompts[0]) if filteredPrompts else None
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]: def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
"""Creates a new prompt if user has permission.""" """Creates a new prompt if user has permission."""
if not self._canModify("prompts"): if not self._canModify(Prompt):
raise PermissionError("No permission to create prompts") raise PermissionError("No permission to create prompts")
# Create prompt record # Create prompt record
createdRecord = self.db.recordCreate("prompts", promptData) createdRecord = self.db.recordCreate(Prompt, promptData)
if not createdRecord or not createdRecord.get("id"): if not createdRecord or not createdRecord.get("id"):
raise ValueError("Failed to create prompt record") raise ValueError("Failed to create prompt record")
# Clear cache to ensure fresh data
self._clearTableCache("prompts")
return createdRecord return createdRecord
@ -306,10 +303,9 @@ class ComponentObjects:
raise ValueError(f"Prompt {promptId} not found") raise ValueError(f"Prompt {promptId} not found")
# Update prompt record directly with the update data # Update prompt record directly with the update data
self.db.recordModify("prompts", promptId, updateData) self.db.recordModify(Prompt, promptId, updateData)
# Clear cache to ensure fresh data # Clear cache to ensure fresh data
self._clearTableCache("prompts")
# Get updated prompt # Get updated prompt
updatedPrompt = self.getPrompt(promptId) updatedPrompt = self.getPrompt(promptId)
@ -329,14 +325,12 @@ class ComponentObjects:
if not prompt: if not prompt:
return False return False
if not self._canModify("prompts", promptId): if not self._canModify(Prompt, promptId):
raise PermissionError(f"No permission to delete prompt {promptId}") raise PermissionError(f"No permission to delete prompt {promptId}")
# Delete prompt # Delete prompt
success = self.db.recordDelete("prompts", promptId) success = self.db.recordDelete(Prompt, promptId)
# Clear cache to ensure fresh data
self._clearTableCache("prompts")
return success return success
@ -347,12 +341,12 @@ class ComponentObjects:
If fileName is provided, also checks for exact name+hash match. If fileName is provided, also checks for exact name+hash match.
Only returns files the current user has access to.""" Only returns files the current user has access to."""
# First get all files with the hash # First get all files with the hash
allFilesWithHash = self.db.getRecordset("files", recordFilter={ allFilesWithHash = self.db.getRecordset(FileItem, recordFilter={
"fileHash": fileHash "fileHash": fileHash
}) })
# Filter by user access using UAM # Filter by user access using UAM
accessibleFiles = self._uam("files", allFilesWithHash) accessibleFiles = self._uam(FileItem, allFilesWithHash)
if not accessibleFiles: if not accessibleFiles:
return None return None
@ -468,8 +462,8 @@ class ComponentObjects:
def getAllFiles(self) -> List[FileItem]: def getAllFiles(self) -> List[FileItem]:
"""Returns files based on user access level.""" """Returns files based on user access level."""
allFiles = self.db.getRecordset("files") allFiles = self.db.getRecordset(FileItem)
filteredFiles = self._uam("files", allFiles) filteredFiles = self._uam(FileItem, allFiles)
# Convert database records to FileItem instances # Convert database records to FileItem instances
fileItems = [] fileItems = []
@ -502,11 +496,11 @@ class ComponentObjects:
def getFile(self, fileId: str) -> Optional[FileItem]: def getFile(self, fileId: str) -> Optional[FileItem]:
"""Returns a file by ID if user has access.""" """Returns a file by ID if user has access."""
files = self.db.getRecordset("files", recordFilter={"id": fileId}) files = self.db.getRecordset(FileItem, recordFilter={"id": fileId})
if not files: if not files:
return None return None
filteredFiles = self._uam("files", files) filteredFiles = self._uam(FileItem, files)
if not filteredFiles: if not filteredFiles:
return None return None
@ -534,7 +528,7 @@ class ComponentObjects:
def _isfileNameUnique(self, fileName: str, excludeFileId: Optional[str] = None) -> bool: def _isfileNameUnique(self, fileName: str, excludeFileId: Optional[str] = None) -> bool:
"""Checks if a fileName is unique for the current user.""" """Checks if a fileName is unique for the current user."""
# Get all files for current user # Get all files for current user
files = self.db.getRecordset("files", recordFilter={ files = self.db.getRecordset(FileItem, recordFilter={
"_createdBy": self.currentUser.id "_createdBy": self.currentUser.id
}) })
@ -566,7 +560,7 @@ class ComponentObjects:
def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem: def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem:
"""Creates a new file entry if user has permission. Computes fileHash and fileSize from content.""" """Creates a new file entry if user has permission. Computes fileHash and fileSize from content."""
import hashlib import hashlib
if not self._canModify("files"): if not self._canModify(FileItem):
raise PermissionError("No permission to create files") raise PermissionError("No permission to create files")
# Ensure fileName is unique # Ensure fileName is unique
@ -589,10 +583,8 @@ class ComponentObjects:
) )
# Store in database # Store in database
self.db.recordCreate("files", fileItem.to_dict()) self.db.recordCreate(FileItem, fileItem)
# Clear cache to ensure fresh data
self._clearTableCache("files")
return fileItem return fileItem
@ -603,7 +595,7 @@ class ComponentObjects:
if not file: if not file:
raise FileNotFoundError(f"File with ID {fileId} not found") raise FileNotFoundError(f"File with ID {fileId} not found")
if not self._canModify("files", fileId): if not self._canModify(FileItem, fileId):
raise PermissionError(f"No permission to update file {fileId}") raise PermissionError(f"No permission to update file {fileId}")
# If fileName is being updated, ensure it's unique # If fileName is being updated, ensure it's unique
@ -611,10 +603,8 @@ class ComponentObjects:
updateData["fileName"] = self._generateUniquefileName(updateData["fileName"], fileId) updateData["fileName"] = self._generateUniquefileName(updateData["fileName"], fileId)
# Update file # Update file
success = self.db.recordModify("files", fileId, updateData) success = self.db.recordModify(FileItem, fileId, updateData)
# Clear cache to ensure fresh data
self._clearTableCache("files")
return success return success
@ -627,30 +617,29 @@ class ComponentObjects:
if not file: if not file:
raise FileNotFoundError(f"File with ID {fileId} not found") raise FileNotFoundError(f"File with ID {fileId} not found")
if not self._canModify("files", fileId): if not self._canModify(FileItem, fileId):
raise PermissionError(f"No permission to delete file {fileId}") raise PermissionError(f"No permission to delete file {fileId}")
# Check for other references to this file (by hash) # Check for other references to this file (by hash)
fileHash = file.fileHash fileHash = file.fileHash
if fileHash: if fileHash:
otherReferences = [f for f in self.db.getRecordset("files", recordFilter={"fileHash": fileHash}) otherReferences = [f for f in self.db.getRecordset(FileItem, recordFilter={"fileHash": fileHash})
if f["id"] != fileId] if f["id"] != fileId]
# Only delete associated fileData if no other references exist # Only delete associated fileData if no other references exist
if not otherReferences: if not otherReferences:
try: try:
fileDataEntries = self.db.getRecordset("fileData", recordFilter={"id": fileId}) fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
if fileDataEntries: if fileDataEntries:
self.db.recordDelete("fileData", fileId) self.db.recordDelete(FileData, fileId)
logger.debug(f"FileData for file {fileId} deleted") logger.debug(f"FileData for file {fileId} deleted")
except Exception as e: except Exception as e:
logger.warning(f"Error deleting FileData for file {fileId}: {str(e)}") logger.warning(f"Error deleting FileData for file {fileId}: {str(e)}")
# Delete the FileItem entry # Delete the FileItem entry
success = self.db.recordDelete("files", fileId) success = self.db.recordDelete(FileItem, fileId)
# Clear cache to ensure fresh data # Clear cache to ensure fresh data
self._clearTableCache("files")
return success return success
@ -709,10 +698,9 @@ class ComponentObjects:
"base64Encoded": base64Encoded "base64Encoded": base64Encoded
} }
self.db.recordCreate("fileData", fileDataObj) self.db.recordCreate(FileData, fileDataObj)
# Clear cache to ensure fresh data # Clear cache to ensure fresh data
self._clearTableCache("fileData")
logger.debug(f"Successfully stored data for file {fileId} (base64Encoded: {base64Encoded})") logger.debug(f"Successfully stored data for file {fileId} (base64Encoded: {base64Encoded})")
return True return True
@ -730,7 +718,7 @@ class ComponentObjects:
import base64 import base64
fileDataEntries = self.db.getRecordset("fileData", recordFilter={"id": fileId}) fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
if not fileDataEntries: if not fileDataEntries:
logger.warning(f"No data found for file ID {fileId}") logger.warning(f"No data found for file ID {fileId}")
return None return None
@ -830,7 +818,7 @@ class ComponentObjects:
"""Saves an uploaded file if user has permission.""" """Saves an uploaded file if user has permission."""
try: try:
# Check file creation permission # Check file creation permission
if not self._canModify("files"): if not self._canModify(FileItem):
raise PermissionError("No permission to upload files") raise PermissionError("No permission to upload files")
logger.debug(f"Starting upload process for file: {fileName}") logger.debug(f"Starting upload process for file: {fileName}")

View file

@ -39,7 +39,7 @@ def get_token_status_for_connection(interface, connection_id: str) -> tuple[str,
try: try:
# Query tokens table for the latest token for this connection # Query tokens table for the latest token for this connection
tokens = interface.db.getRecordset( tokens = interface.db.getRecordset(
table="tokens", Token,
recordFilter={"connectionId": connection_id} recordFilter={"connectionId": connection_id}
) )
@ -93,9 +93,6 @@ async def get_connections(
try: try:
interface = getInterface(currentUser) interface = getInterface(currentUser)
# Clear connections cache to ensure fresh data
interface.db.clearTableCache("connections")
# SECURITY FIX: All users (including admins) can only see their own connections # SECURITY FIX: All users (including admins) can only see their own connections
# This prevents admin from seeing other users' connections and causing confusion # This prevents admin from seeing other users' connections and causing confusion
connections = interface.getUserConnections(currentUser.id) connections = interface.getUserConnections(currentUser.id)
@ -179,10 +176,8 @@ async def create_connection(
) )
# Save connection record - models now handle timestamp serialization automatically # Save connection record - models now handle timestamp serialization automatically
interface.db.recordModify("connections", connection.id, connection.to_dict()) interface.db.recordModify(UserConnection, connection.id, connection.to_dict())
# Clear cache to ensure fresh data
interface.db.clearTableCache("connections")
return connection return connection
@ -235,10 +230,8 @@ async def update_connection(
connection.lastChecked = get_utc_timestamp() connection.lastChecked = get_utc_timestamp()
# Update connection - models now handle timestamp serialization automatically # Update connection - models now handle timestamp serialization automatically
interface.db.recordModify("connections", connectionId, connection.to_dict()) interface.db.recordModify(UserConnection, connectionId, connection.to_dict())
# Clear cache to ensure fresh data
interface.db.clearTableCache("connections")
# Get token status for the updated connection # Get token status for the updated connection
token_status, token_expires_at = get_token_status_for_connection(interface, connectionId) token_status, token_expires_at = get_token_status_for_connection(interface, connectionId)
@ -372,10 +365,8 @@ async def disconnect_service(
connection.lastChecked = get_utc_timestamp() connection.lastChecked = get_utc_timestamp()
# Update connection record - models now handle timestamp serialization automatically # Update connection record - models now handle timestamp serialization automatically
interface.db.recordModify("connections", connectionId, connection.to_dict()) interface.db.recordModify(UserConnection, connectionId, connection.to_dict())
# Clear cache to ensure fresh data
interface.db.clearTableCache("connections")
return {"message": "Service disconnected successfully"} return {"message": "Service disconnected successfully"}

View file

@ -173,7 +173,8 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
rootInterface = getRootInterface() rootInterface = getRootInterface()
# Prefer connection flow reuse; fallback to user access token # Prefer connection flow reuse; fallback to user access token
if connection_id: if connection_id:
existing_tokens = rootInterface.db.getRecordset("tokens", recordFilter={ from modules.interfaces.interfaceAppModel import Token
existing_tokens = rootInterface.db.getRecordset(Token, recordFilter={
"connectionId": connection_id, "connectionId": connection_id,
"authority": AuthAuthority.GOOGLE "authority": AuthAuthority.GOOGLE
}) })
@ -182,7 +183,7 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
existing_tokens.sort(key=lambda x: x.get("createdAt", 0), reverse=True) existing_tokens.sort(key=lambda x: x.get("createdAt", 0), reverse=True)
token_response["refresh_token"] = existing_tokens[0].get("tokenRefresh", "") token_response["refresh_token"] = existing_tokens[0].get("tokenRefresh", "")
if not token_response.get("refresh_token") and user_id: if not token_response.get("refresh_token") and user_id:
existing_access_tokens = rootInterface.db.getRecordset("tokens", recordFilter={ existing_access_tokens = rootInterface.db.getRecordset(Token, recordFilter={
"userId": user_id, "userId": user_id,
"connectionId": None, "connectionId": None,
"authority": AuthAuthority.GOOGLE "authority": AuthAuthority.GOOGLE
@ -358,10 +359,9 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
connection.externalEmail = user_info.get("email") connection.externalEmail = user_info.get("email")
# Update connection record directly # Update connection record directly
rootInterface.db.recordModify("connections", connection_id, connection.to_dict()) from modules.interfaces.interfaceAppModel import UserConnection
rootInterface.db.recordModify(UserConnection, connection_id, connection.to_dict())
# Clear cache to ensure fresh data
rootInterface.db.clearTableCache("connections")
# Save token # Save token
token = Token( token = Token(
@ -543,7 +543,7 @@ async def refresh_token(
google_connection.status = ConnectionStatus.ACTIVE google_connection.status = ConnectionStatus.ACTIVE
# Save updated connection # Save updated connection
appInterface.db.recordModify("connections", google_connection.id, google_connection.to_dict()) appInterface.db.recordModify(UserConnection, google_connection.id, google_connection.to_dict())
# Calculate time until expiration # Calculate time until expiration
current_time = get_utc_timestamp() current_time = get_utc_timestamp()

View file

@ -52,7 +52,8 @@ async def login(
rootInterface = getRootInterface() rootInterface = getRootInterface()
# Get default mandate ID # Get default mandate ID
defaultMandateId = rootInterface.getInitialId("mandates") from modules.interfaces.interfaceAppModel import Mandate
defaultMandateId = rootInterface.getInitialId(Mandate)
if not defaultMandateId: if not defaultMandateId:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -146,7 +147,8 @@ async def register_user(
appInterface = getRootInterface() appInterface = getRootInterface()
# Get default mandate ID # Get default mandate ID
defaultMandateId = appInterface.getInitialId("mandates") from modules.interfaces.interfaceAppModel import Mandate
defaultMandateId = appInterface.getInitialId(Mandate)
if not defaultMandateId: if not defaultMandateId:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

View file

@ -309,10 +309,8 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
connection.externalEmail = user_info.get("mail") connection.externalEmail = user_info.get("mail")
# Update connection record directly # Update connection record directly
rootInterface.db.recordModify("connections", connection_id, connection.to_dict()) rootInterface.db.recordModify(UserConnection, connection_id, connection.to_dict())
# Clear cache to ensure fresh data
rootInterface.db.clearTableCache("connections")
# Save token # Save token
@ -524,7 +522,7 @@ async def refresh_token(
msft_connection.status = ConnectionStatus.ACTIVE msft_connection.status = ConnectionStatus.ACTIVE
# Save updated connection # Save updated connection
appInterface.db.recordModify("connections", msft_connection.id, msft_connection.to_dict()) appInterface.db.recordModify(UserConnection, msft_connection.id, msft_connection.to_dict())
# Calculate time until expiration # Calculate time until expiration
current_time = get_utc_timestamp() current_time = get_utc_timestamp()

View file

@ -57,7 +57,7 @@ async def get_workflows(
"""Get all workflows for the current user.""" """Get all workflows for the current user."""
try: try:
appInterface = getInterface(currentUser) appInterface = getInterface(currentUser)
workflows_data = appInterface.getAllWorkflows() workflows_data = appInterface.getWorkflows()
# Convert raw dictionaries to ChatWorkflow objects # Convert raw dictionaries to ChatWorkflow objects
workflows = [] workflows = []
@ -136,7 +136,7 @@ async def update_workflow(
workflowInterface = getInterface(currentUser) workflowInterface = getInterface(currentUser)
# Get raw workflow data from database to check permissions # Get raw workflow data from database to check permissions
workflows = workflowInterface.db.getRecordset("workflows", recordFilter={"id": workflowId}) workflows = workflowInterface.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
if not workflows: if not workflows:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -225,7 +225,7 @@ async def get_workflow_logs(
) )
# Get all logs # Get all logs
allLogs = interfaceChat.getWorkflowLogs(workflowId) allLogs = interfaceChat.getLogs(workflowId)
# Apply selective data transfer if logId is provided # Apply selective data transfer if logId is provided
if logId: if logId:
@ -268,7 +268,7 @@ async def get_workflow_messages(
) )
# Get all messages # Get all messages
allMessages = interfaceChat.getWorkflowMessages(workflowId) allMessages = interfaceChat.getMessages(workflowId)
# Apply selective data transfer if messageId is provided # Apply selective data transfer if messageId is provided
if messageId: if messageId:
@ -356,7 +356,7 @@ async def delete_workflow(
interfaceChat = getServiceChat(currentUser) interfaceChat = getServiceChat(currentUser)
# Get raw workflow data from database to check permissions # Get raw workflow data from database to check permissions
workflows = interfaceChat.db.getRecordset("workflows", recordFilter={"id": workflowId}) workflows = interfaceChat.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
if not workflows: if not workflows:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -419,7 +419,7 @@ async def delete_workflow_message(
) )
# Delete the message # Delete the message
success = interfaceChat.deleteWorkflowMessage(workflowId, messageId) success = interfaceChat.deleteMessage(workflowId, messageId)
if not success: if not success:
raise HTTPException( raise HTTPException(

View file

@ -76,12 +76,12 @@ class WorkflowManager:
"taskProgress": "pending", "taskProgress": "pending",
"actionProgress": "pending" "actionProgress": "pending"
} }
message = self.chatInterface.createWorkflowMessage(stopped_message) message = self.chatInterface.createMessage(stopped_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
# Add log entry # Add log entry
self.chatInterface.createWorkflowLog({ self.chatInterface.createLog({
"workflowId": workflow.id, "workflowId": workflow.id,
"message": "Workflow stopped by user", "message": "Workflow stopped by user",
"type": "warning", "type": "warning",
@ -120,12 +120,12 @@ class WorkflowManager:
"taskProgress": "fail", "taskProgress": "fail",
"actionProgress": "fail" "actionProgress": "fail"
} }
message = self.chatInterface.createWorkflowMessage(error_message) message = self.chatInterface.createMessage(error_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
# Add error log entry # Add error log entry
self.chatInterface.createWorkflowLog({ self.chatInterface.createLog({
"workflowId": workflow.id, "workflowId": workflow.id,
"message": f"Workflow failed: {str(e)}", "message": f"Workflow failed: {str(e)}",
"type": "error", "type": "error",
@ -165,16 +165,19 @@ class WorkflowManager:
"actionProgress": "pending" "actionProgress": "pending"
} }
# Add documents if any # Create message first to get messageId
if userInput.listFileId: message = self.chatInterface.createMessage(messageData)
# Process file IDs and add to message data
documents = await self.chatManager.service.processFileIds(userInput.listFileId)
messageData["documents"] = documents
# Create message using interface
message = self.chatInterface.createWorkflowMessage(messageData)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
# Add documents if any, now with messageId
if userInput.listFileId:
# Process file IDs and add to message data
documents = await self.chatManager.service.processFileIds(userInput.listFileId, message.id)
message.documents = documents
# Update the message with documents in database
self.chatInterface.updateMessage(message.id, {"documents": [doc.to_dict() for doc in documents]})
return message return message
else: else:
raise Exception("Failed to create first message") raise Exception("Failed to create first message")
@ -241,7 +244,7 @@ class WorkflowManager:
} }
# Create message using interface # Create message using interface
message = self.chatInterface.createWorkflowMessage(messageData) message = self.chatInterface.createMessage(messageData)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
@ -256,7 +259,7 @@ class WorkflowManager:
}) })
# Add completion log entry # Add completion log entry
self.chatInterface.createWorkflowLog({ self.chatInterface.createLog({
"workflowId": workflow.id, "workflowId": workflow.id,
"message": "Workflow completed", "message": "Workflow completed",
"type": "success", "type": "success",
@ -294,7 +297,7 @@ class WorkflowManager:
"taskProgress": "stopped", "taskProgress": "stopped",
"actionProgress": "stopped" "actionProgress": "stopped"
} }
message = self.chatInterface.createWorkflowMessage(stopped_message) message = self.chatInterface.createMessage(stopped_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
@ -326,7 +329,7 @@ class WorkflowManager:
"taskProgress": "stopped", "taskProgress": "stopped",
"actionProgress": "stopped" "actionProgress": "stopped"
} }
message = self.chatInterface.createWorkflowMessage(stopped_message) message = self.chatInterface.createMessage(stopped_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
@ -341,7 +344,7 @@ class WorkflowManager:
}) })
# Add stopped log entry # Add stopped log entry
self.chatInterface.createWorkflowLog({ self.chatInterface.createLog({
"workflowId": workflow.id, "workflowId": workflow.id,
"message": "Workflow stopped by user", "message": "Workflow stopped by user",
"type": "warning", "type": "warning",
@ -368,7 +371,7 @@ class WorkflowManager:
"taskProgress": "fail", "taskProgress": "fail",
"actionProgress": "fail" "actionProgress": "fail"
} }
message = self.chatInterface.createWorkflowMessage(error_message) message = self.chatInterface.createMessage(error_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
@ -383,7 +386,7 @@ class WorkflowManager:
}) })
# Add failed log entry # Add failed log entry
self.chatInterface.createWorkflowLog({ self.chatInterface.createLog({
"workflowId": workflow.id, "workflowId": workflow.id,
"message": f"Workflow failed: {workflow_result.error or 'Unknown error'}", "message": f"Workflow failed: {workflow_result.error or 'Unknown error'}",
"type": "error", "type": "error",
@ -411,7 +414,7 @@ class WorkflowManager:
"actionProgress": "success" "actionProgress": "success"
} }
message = self.chatInterface.createWorkflowMessage(summary_message) message = self.chatInterface.createMessage(summary_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)
@ -426,7 +429,7 @@ class WorkflowManager:
}) })
# Add completion log entry # Add completion log entry
self.chatInterface.createWorkflowLog({ self.chatInterface.createLog({
"workflowId": workflow.id, "workflowId": workflow.id,
"message": "Workflow completed successfully", "message": "Workflow completed successfully",
"type": "success", "type": "success",
@ -454,7 +457,7 @@ class WorkflowManager:
"taskProgress": "fail", "taskProgress": "fail",
"actionProgress": "fail" "actionProgress": "fail"
} }
message = self.chatInterface.createWorkflowMessage(error_message) message = self.chatInterface.createMessage(error_message)
if message: if message:
workflow.messages.append(message) workflow.messages.append(message)

View file

@ -4,7 +4,8 @@ TODO
# System # System
- database - database
- db initialization as separate function to create root mandate, then sysadmin with hashed passwords --> using the connector according to env configuration - db initialization as separate function to create root mandate, then sysadmin with hashed passwords --> using the connector according to env configuration
- config page for: db reset - settings: UI page for: db new (delete if exists and init), then to add mandate root and sysadmin, log download --> in the api to add connector settings with the according endpoints
- access model as matrix, not as code --> to have view, add, update, delete with the rights on level table and attribute for all, my (created by me), my mandate (mandate I am in), none (no access)
- document handling centralized - document handling centralized
- ai handling centralized - ai handling centralized
- neutralizer to activate AND put back placeholders to the returned data - neutralizer to activate AND put back placeholders to the returned data

View file

@ -1,8 +0,0 @@
New features
- Limiter and tracking of ip adress access
- Sessions improved
- user and connection consequently separated
- seamless local and external authorities integration
- audit trail
- nda disclaimer in login window
- CSRF Tokens included in forms

1
query Normal file
View file

@ -0,0 +1 @@
postgresql

View file

@ -91,3 +91,6 @@ linkify-it-py>=1.0.0
mdit-py-plugins>=0.3.0 mdit-py-plugins>=0.3.0
pyviz-comms>=2.0.0 pyviz-comms>=2.0.0
xyzservices>=2021.09.1 xyzservices>=2021.09.1
# PostgreSQL connector dependencies
psycopg2-binary==2.9.9

View file

@ -1,237 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify concurrency improvements in DatabaseConnector.
This script simulates multiple users accessing the database simultaneously.
"""
import os
import sys
import time
import threading
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import tempfile
import shutil
# Add the gateway directory to the path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from modules.connectors.connectorDbJson import DatabaseConnector
from modules.connectors.connectorPool import get_connector, return_connector
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_concurrent_record_operations():
"""Test concurrent record creation, modification, and deletion."""
# Create temporary database directory
temp_dir = tempfile.mkdtemp()
db_host = temp_dir
db_database = "test_db"
try:
logger.info("Starting concurrency test...")
def user_operation(user_id: int, operation_count: int = 10):
"""Simulate a user performing database operations."""
try:
# Get a dedicated connector for this user
db = get_connector(
dbHost=db_host,
dbDatabase=db_database,
userId=f"user_{user_id}"
)
results = []
for i in range(operation_count):
# Create a record
record = {
"id": f"user_{user_id}_record_{i}",
"data": f"User {user_id} data {i}",
"timestamp": time.time()
}
# Create record
created = db.recordCreate("test_table", record)
results.append(f"Created: {created['id']}")
# Modify record
record["data"] = f"Modified by user {user_id} - {i}"
modified = db.recordModify("test_table", record["id"], record)
results.append(f"Modified: {modified['id']}")
# Small delay to increase chance of race conditions
time.sleep(0.001)
# Return connector to pool
return_connector(db)
return results
except Exception as e:
logger.error(f"User {user_id} error: {e}")
return [f"Error: {e}"]
# Test with multiple concurrent users
num_users = 20
operations_per_user = 5
logger.info(f"Testing with {num_users} users, {operations_per_user} operations each")
start_time = time.time()
with ThreadPoolExecutor(max_workers=num_users) as executor:
# Submit all user operations
futures = [
executor.submit(user_operation, user_id, operations_per_user)
for user_id in range(num_users)
]
# Collect results
all_results = []
for future in as_completed(futures):
try:
result = future.result()
all_results.extend(result)
except Exception as e:
logger.error(f"Future error: {e}")
end_time = time.time()
# Verify data integrity
db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier")
# Check that all records exist and are consistent
all_records = db.getRecordset("test_table")
expected_count = num_users * operations_per_user
logger.info(f"Expected records: {expected_count}")
logger.info(f"Actual records: {len(all_records)}")
logger.info(f"Test completed in {end_time - start_time:.2f} seconds")
# Check for data consistency
record_ids = set(record["id"] for record in all_records)
expected_ids = set(f"user_{user_id}_record_{i}" for user_id in range(num_users) for i in range(operations_per_user))
missing_ids = expected_ids - record_ids
extra_ids = record_ids - expected_ids
if missing_ids:
logger.error(f"Missing records: {missing_ids}")
if extra_ids:
logger.error(f"Extra records: {extra_ids}")
# Check for data corruption (records with wrong user data)
corrupted_records = []
for record in all_records:
record_id = record["id"]
user_id = int(record_id.split("_")[1])
if f"Modified by user {user_id}" not in record.get("data", ""):
corrupted_records.append(record_id)
if corrupted_records:
logger.error(f"Corrupted records: {corrupted_records}")
success = len(missing_ids) == 0 and len(extra_ids) == 0 and len(corrupted_records) == 0
if success:
logger.info("✅ Concurrency test PASSED - No data corruption detected")
else:
logger.error("❌ Concurrency test FAILED - Data corruption detected")
return success
finally:
# Cleanup
try:
shutil.rmtree(temp_dir)
logger.info("Cleaned up temporary directory")
except Exception as e:
logger.error(f"Error cleaning up: {e}")
def test_metadata_consistency():
"""Test that metadata operations are atomic."""
temp_dir = tempfile.mkdtemp()
db_host = temp_dir
db_database = "test_metadata"
try:
logger.info("Testing metadata consistency...")
def concurrent_metadata_operations(user_id: int):
"""Perform concurrent metadata operations."""
db = get_connector(
dbHost=db_host,
dbDatabase=db_database,
userId=f"user_{user_id}"
)
try:
# Create multiple records rapidly
for i in range(10):
record = {
"id": f"user_{user_id}_meta_{i}",
"data": f"Metadata test {user_id}-{i}"
}
db.recordCreate("metadata_test", record)
time.sleep(0.001) # Small delay
return True
except Exception as e:
logger.error(f"Metadata test error for user {user_id}: {e}")
return False
finally:
return_connector(db)
# Run concurrent metadata operations
with ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(concurrent_metadata_operations, i) for i in range(10)]
results = [future.result() for future in as_completed(futures)]
# Verify metadata consistency
db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier")
records = db.getRecordset("metadata_test")
# Check that metadata is consistent
metadata = db._loadTableMetadata("metadata_test")
expected_count = len(records)
actual_count = len(metadata["recordIds"])
logger.info(f"Expected record count: {expected_count}")
logger.info(f"Metadata record count: {actual_count}")
success = expected_count == actual_count
if success:
logger.info("✅ Metadata consistency test PASSED")
else:
logger.error("❌ Metadata consistency test FAILED")
return success
finally:
try:
shutil.rmtree(temp_dir)
except Exception as e:
logger.error(f"Error cleaning up: {e}")
if __name__ == "__main__":
logger.info("Starting concurrency tests...")
# Test 1: Concurrent record operations
test1_passed = test_concurrent_record_operations()
# Test 2: Metadata consistency
test2_passed = test_metadata_consistency()
# Overall result
if test1_passed and test2_passed:
logger.info("🎉 All concurrency tests PASSED!")
sys.exit(0)
else:
logger.error("💥 Some concurrency tests FAILED!")
sys.exit(1)

View file

@ -1,108 +0,0 @@
"""Tests for Tavliy web search."""
import pytest
import logging
from modules.interfaces.interfaceChatModel import ActionResult
from gateway.modules.interfaces.interfaceWebModel import (
WebSearchRequest,
WebCrawlRequest,
WebScrapeRequest,
)
from gateway.modules.connectors.connectorWebTavily import ConnectorTavily
logger = logging.getLogger(__name__)
@pytest.mark.asyncio
@pytest.mark.expensive
async def test_tavily_connector_search_test_live_api():
logger.info("Testing Tavliy connector search with live API calls")
# Test request
request = WebSearchRequest(query="How old is the Earth?", max_results=5)
# Tavily instance
connectorWebTavily = await ConnectorTavily.create()
# Search test
action_result = await connectorWebTavily.search_urls(request=request)
# Check results
assert isinstance(action_result, ActionResult)
logger.info("=" * 20)
logger.info(f"Action result success status: {action_result.success}")
logger.info(f"Action result error: {action_result.error}")
logger.info(f"Action result label: {action_result.resultLabel}")
logger.info("Documents:")
for doc in action_result.documents:
logger.info("-" * 10)
logger.info(f" - Document Name: {doc.documentName}")
logger.info(f" - Document Mime Type: {doc.mimeType}")
logger.info(f" - Document Data: {doc.documentData}")
@pytest.mark.asyncio
@pytest.mark.expensive
async def test_tavily_connector_crawl_test_live_api():
logger.info("Testing Tavily connector crawl with live API calls")
# Test request
urls = [
"https://en.wikipedia.org/wiki/Earth",
"https://valueon.ch",
]
request = WebCrawlRequest(urls=urls)
# Tavily instance
connectorWebTavily = await ConnectorTavily.create()
# Crawl test
action_result = await connectorWebTavily.crawl_urls(request=request)
# Check results
assert isinstance(action_result, ActionResult)
logger.info("=" * 20)
logger.info(f"Action result success status: {action_result.success}")
logger.info(f"Action result error: {action_result.error}")
logger.info(f"Action result label: {action_result.resultLabel}")
logger.info("Documents:")
for doc in action_result.documents:
logger.info("-" * 10)
logger.info(f" - Document Name: {doc.documentName}")
logger.info(f" - Document Mime Type: {doc.mimeType}")
logger.info(f" - Document Data: {doc.documentData}")
@pytest.mark.asyncio
@pytest.mark.expensive
async def test_tavily_connector_scrape_test_live_api():
logger.info("Testing Tavily connector scrape with live API calls")
# Test request with query
request = WebScrapeRequest(query="How old is the Earth?", max_results=3)
# Tavily instance
connectorWebTavily = await ConnectorTavily.create()
# Scrape test
action_result = await connectorWebTavily.scrape(request=request)
# Check results
assert isinstance(action_result, ActionResult)
logger.info("=" * 20)
logger.info(f"Action result success status: {action_result.success}")
logger.info(f"Action result error: {action_result.error}")
logger.info(f"Action result label: {action_result.resultLabel}")
logger.info("Documents:")
for doc in action_result.documents:
logger.info("-" * 10)
logger.info(f" - Document Name: {doc.documentName}")
logger.info(f" - Document Mime Type: {doc.mimeType}")
logger.info(f" - Document Data: {doc.documentData}")