database attached
This commit is contained in:
parent
98a4323b36
commit
8fbbd35055
70 changed files with 1923 additions and 1259 deletions
7
app.py
7
app.py
|
|
@ -63,10 +63,11 @@ def initLogging():
|
|||
class EmojiFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
if isinstance(record.msg, str):
|
||||
# Remove emojis and other Unicode characters that might cause encoding issues
|
||||
# Remove only emojis, preserve other Unicode characters like quotes
|
||||
import re
|
||||
# Remove emojis and other Unicode symbols
|
||||
record.msg = re.sub(r'[^\x00-\x7F]+', '[EMOJI]', record.msg)
|
||||
import unicodedata
|
||||
# Remove emoji characters specifically
|
||||
record.msg = ''.join(char for char in record.msg if unicodedata.category(char) != 'So' or not (0x1F600 <= ord(char) <= 0x1F64F or 0x1F300 <= ord(char) <= 0x1F5FF or 0x1F680 <= ord(char) <= 0x1F6FF or 0x1F1E0 <= ord(char) <= 0x1F1FF or 0x2600 <= ord(char) <= 0x26FF or 0x2700 <= ord(char) <= 0x27BF))
|
||||
return True
|
||||
|
||||
# Configure handlers based on config
|
||||
|
|
|
|||
69
env_dev.env
Normal file
69
env_dev.env
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# Development Environment Configuration
|
||||
|
||||
# System Configuration
|
||||
APP_ENV_TYPE = dev
|
||||
APP_ENV_LABEL = Development Instance Patrick
|
||||
APP_API_URL = http://localhost:8000
|
||||
|
||||
# Database Configuration for Application
|
||||
# JSON File Storage (current)
|
||||
# DB_APP_HOST=D:/Temp/_powerondb
|
||||
# DB_APP_DATABASE=app
|
||||
# DB_APP_USER=dev_user
|
||||
# DB_APP_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_APP_HOST=localhost
|
||||
DB_APP_DATABASE=poweron_app_dev
|
||||
DB_APP_USER=poweron_dev
|
||||
DB_APP_PASSWORD_SECRET=dev_password
|
||||
DB_APP_PORT=5432
|
||||
|
||||
# Database Configuration Chat
|
||||
# JSON File Storage (current)
|
||||
# DB_CHAT_HOST=D:/Temp/_powerondb
|
||||
# DB_CHAT_DATABASE=chat
|
||||
# DB_CHAT_USER=dev_user
|
||||
# DB_CHAT_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_CHAT_HOST=localhost
|
||||
DB_CHAT_DATABASE=poweron_chat_dev
|
||||
DB_CHAT_USER=poweron_dev
|
||||
DB_CHAT_PASSWORD_SECRET=dev_password
|
||||
DB_CHAT_PORT=5432
|
||||
|
||||
# Database Configuration Management
|
||||
# JSON File Storage (current)
|
||||
# DB_MANAGEMENT_HOST=D:/Temp/_powerondb
|
||||
# DB_MANAGEMENT_DATABASE=management
|
||||
# DB_MANAGEMENT_USER=dev_user
|
||||
# DB_MANAGEMENT_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_MANAGEMENT_HOST=localhost
|
||||
DB_MANAGEMENT_DATABASE=poweron_management_dev
|
||||
DB_MANAGEMENT_USER=poweron_dev
|
||||
DB_MANAGEMENT_PASSWORD_SECRET=dev_password
|
||||
DB_MANAGEMENT_PORT=5432
|
||||
|
||||
# Security Configuration
|
||||
APP_JWT_SECRET_SECRET=dev_jwt_secret_token
|
||||
APP_TOKEN_EXPIRY=300
|
||||
|
||||
# CORS Configuration
|
||||
APP_ALLOWED_ORIGINS=http://localhost:8080,https://playground.poweron-center.net
|
||||
|
||||
# Logging configuration
|
||||
APP_LOGGING_LOG_LEVEL = DEBUG
|
||||
APP_LOGGING_LOG_FILE = poweron.log
|
||||
APP_LOGGING_FORMAT = %(asctime)s - %(levelname)s - %(name)s - %(message)s
|
||||
APP_LOGGING_DATE_FORMAT = %Y-%m-%d %H:%M:%S
|
||||
APP_LOGGING_CONSOLE_ENABLED = True
|
||||
APP_LOGGING_FILE_ENABLED = True
|
||||
APP_LOGGING_ROTATION_SIZE = 10485760
|
||||
APP_LOGGING_BACKUP_COUNT = 5
|
||||
|
||||
# Service Redirects
|
||||
Service_MSFT_REDIRECT_URI = http://localhost:8000/api/msft/auth/callback
|
||||
Service_GOOGLE_REDIRECT_URI = http://localhost:8000/api/google/auth/callback
|
||||
48
env_int.env
48
env_int.env
|
|
@ -6,22 +6,46 @@ APP_ENV_LABEL = Integration Instance
|
|||
APP_API_URL = https://gateway-int.poweron-center.net
|
||||
|
||||
# Database Configuration Application
|
||||
DB_APP_HOST=/home/_powerondb
|
||||
DB_APP_DATABASE=app
|
||||
DB_APP_USER=dev_user
|
||||
DB_APP_PASSWORD_SECRET=dev_password
|
||||
# JSON File Storage (current)
|
||||
# DB_APP_HOST=/home/_powerondb
|
||||
# DB_APP_DATABASE=app
|
||||
# DB_APP_USER=dev_user
|
||||
# DB_APP_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_APP_HOST=gateway-int-db.poweron-center.net
|
||||
DB_APP_DATABASE=poweron_app_int
|
||||
DB_APP_USER=poweron_int
|
||||
DB_APP_PASSWORD_SECRET=int_password_secure
|
||||
DB_APP_PORT=5432
|
||||
|
||||
# Database Configuration Chat
|
||||
DB_CHAT_HOST=/home/_powerondb
|
||||
DB_CHAT_DATABASE=chat
|
||||
DB_CHAT_USER=dev_user
|
||||
DB_CHAT_PASSWORD_SECRET=dev_password
|
||||
# JSON File Storage (current)
|
||||
# DB_CHAT_HOST=/home/_powerondb
|
||||
# DB_CHAT_DATABASE=chat
|
||||
# DB_CHAT_USER=dev_user
|
||||
# DB_CHAT_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_CHAT_HOST=gateway-int-db.poweron-center.net
|
||||
DB_CHAT_DATABASE=poweron_chat_int
|
||||
DB_CHAT_USER=poweron_int
|
||||
DB_CHAT_PASSWORD_SECRET=int_password_secure
|
||||
DB_CHAT_PORT=5432
|
||||
|
||||
# Database Configuration Management
|
||||
DB_MANAGEMENT_HOST=/home/_powerondb
|
||||
DB_MANAGEMENT_DATABASE=management
|
||||
DB_MANAGEMENT_USER=dev_user
|
||||
DB_MANAGEMENT_PASSWORD_SECRET=dev_password
|
||||
# JSON File Storage (current)
|
||||
# DB_MANAGEMENT_HOST=/home/_powerondb
|
||||
# DB_MANAGEMENT_DATABASE=management
|
||||
# DB_MANAGEMENT_USER=dev_user
|
||||
# DB_MANAGEMENT_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_MANAGEMENT_HOST=gateway-int-db.poweron-center.net
|
||||
DB_MANAGEMENT_DATABASE=poweron_management_int
|
||||
DB_MANAGEMENT_USER=poweron_int
|
||||
DB_MANAGEMENT_PASSWORD_SECRET=int_password_secure
|
||||
DB_MANAGEMENT_PORT=5432
|
||||
|
||||
# Security Configuration
|
||||
APP_JWT_SECRET_SECRET=dev_jwt_secret_token
|
||||
|
|
|
|||
48
env_prod.env
48
env_prod.env
|
|
@ -6,22 +6,46 @@ APP_ENV_LABEL = Production Instance
|
|||
APP_API_URL = https://gateway.poweron-center.net
|
||||
|
||||
# Database Configuration Application
|
||||
DB_APP_HOST=/home/_powerondb
|
||||
DB_APP_DATABASE=app
|
||||
DB_APP_USER=dev_user
|
||||
DB_APP_PASSWORD_SECRET=dev_password
|
||||
# JSON File Storage (current)
|
||||
# DB_APP_HOST=/home/_powerondb
|
||||
# DB_APP_DATABASE=app
|
||||
# DB_APP_USER=dev_user
|
||||
# DB_APP_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_APP_HOST=gateway-prod-server.postgres.database.azure.com
|
||||
DB_APP_DATABASE=gateway-app
|
||||
DB_APP_USER=gzxxmcrdhn
|
||||
DB_APP_PASSWORD_SECRET=prod_password_very_secure.2025
|
||||
DB_APP_PORT=5432
|
||||
|
||||
# Database Configuration Chat
|
||||
DB_CHAT_HOST=/home/_powerondb
|
||||
DB_CHAT_DATABASE=chat
|
||||
DB_CHAT_USER=dev_user
|
||||
DB_CHAT_PASSWORD_SECRET=dev_password
|
||||
# JSON File Storage (current)
|
||||
# DB_CHAT_HOST=/home/_powerondb
|
||||
# DB_CHAT_DATABASE=chat
|
||||
# DB_CHAT_USER=gzxxmcrdhn
|
||||
# DB_CHAT_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_CHAT_HOST=gateway-prod-server.postgres.database.azure.com
|
||||
DB_CHAT_DATABASE=gateway-chat
|
||||
DB_CHAT_USER=poweron_prod
|
||||
DB_CHAT_PASSWORD_SECRET=prod_password_very_secure.2025
|
||||
DB_CHAT_PORT=5432
|
||||
|
||||
# Database Configuration Management
|
||||
DB_MANAGEMENT_HOST=/home/_powerondb
|
||||
DB_MANAGEMENT_DATABASE=management
|
||||
DB_MANAGEMENT_USER=dev_user
|
||||
DB_MANAGEMENT_PASSWORD_SECRET=dev_password
|
||||
# JSON File Storage (current)
|
||||
# DB_MANAGEMENT_HOST=/home/_powerondb
|
||||
# DB_MANAGEMENT_DATABASE=gateway-management
|
||||
# DB_MANAGEMENT_USER=gzxxmcrdhn
|
||||
# DB_MANAGEMENT_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_MANAGEMENT_HOST=gateway-prod-server.postgres.database.azure.com
|
||||
DB_MANAGEMENT_DATABASE=gateway-management
|
||||
DB_MANAGEMENT_USER=poweron_prod
|
||||
DB_MANAGEMENT_PASSWORD_SECRET=prod_password_very_secure.2025
|
||||
DB_MANAGEMENT_PORT=5432
|
||||
|
||||
# Security Configuration
|
||||
APP_JWT_SECRET_SECRET=dev_jwt_secret_token
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class DocumentGenerator:
|
|||
logger.error(f"Error processing single document: {str(e)}")
|
||||
return None
|
||||
|
||||
def createDocumentsFromActionResult(self, action_result, action, workflow) -> List[Any]:
|
||||
def createDocumentsFromActionResult(self, action_result, action, workflow, message_id=None) -> List[Any]:
|
||||
"""
|
||||
Create actual document objects from action result and store them in the system.
|
||||
Returns a list of created document objects with proper workflow context.
|
||||
|
|
@ -103,7 +103,8 @@ class DocumentGenerator:
|
|||
fileName=document_name,
|
||||
mimeType=mime_type,
|
||||
content=content,
|
||||
base64encoded=False
|
||||
base64encoded=False,
|
||||
messageId=message_id
|
||||
)
|
||||
if document:
|
||||
# Set workflow context on the document if possible
|
||||
|
|
|
|||
|
|
@ -250,7 +250,7 @@ class HandlingTasks:
|
|||
"taskProgress": "pending"
|
||||
}
|
||||
|
||||
message = self.chatInterface.createWorkflowMessage(message_data)
|
||||
message = self.chatInterface.createMessage(message_data)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
|
|
@ -492,7 +492,7 @@ class HandlingTasks:
|
|||
if task_step.userMessage:
|
||||
task_start_message["message"] += f"\n\n💬 {task_step.userMessage}"
|
||||
|
||||
message = self.chatInterface.createWorkflowMessage(task_start_message)
|
||||
message = self.chatInterface.createMessage(task_start_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
logger.info(f"Task start message created for task {task_index}")
|
||||
|
|
@ -569,7 +569,7 @@ class HandlingTasks:
|
|||
"actionNumber": action_number
|
||||
})
|
||||
|
||||
message = self.chatInterface.createWorkflowMessage(action_start_message)
|
||||
message = self.chatInterface.createMessage(action_start_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
logger.info(f"Action start message created for action {action_number}")
|
||||
|
|
@ -623,7 +623,7 @@ class HandlingTasks:
|
|||
"taskProgress": "success"
|
||||
}
|
||||
|
||||
message = self.chatInterface.createWorkflowMessage(task_completion_message)
|
||||
message = self.chatInterface.createMessage(task_completion_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
logger.info(f"Task completion message created for task {task_index}")
|
||||
|
|
@ -715,7 +715,7 @@ class HandlingTasks:
|
|||
"taskProgress": "retry"
|
||||
}
|
||||
|
||||
message = self.chatInterface.createWorkflowMessage(retry_message)
|
||||
message = self.chatInterface.createMessage(retry_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
|
|
@ -768,7 +768,7 @@ class HandlingTasks:
|
|||
}
|
||||
|
||||
try:
|
||||
message = self.chatInterface.createWorkflowMessage(message_data)
|
||||
message = self.chatInterface.createMessage(message_data)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
logger.info(f"Created user-facing retry message for failed task: {task_step.objective}")
|
||||
|
|
@ -822,7 +822,7 @@ class HandlingTasks:
|
|||
}
|
||||
|
||||
try:
|
||||
message = self.chatInterface.createWorkflowMessage(message_data)
|
||||
message = self.chatInterface.createMessage(message_data)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
logger.info(f"Created user-facing error message for failed task: {task_step.objective}")
|
||||
|
|
@ -1030,8 +1030,11 @@ class HandlingTasks:
|
|||
if "execParameters" not in actionData:
|
||||
actionData["execParameters"] = {}
|
||||
|
||||
# Use generic field separation based on TaskAction model
|
||||
simple_fields, object_fields = self.chatInterface._separate_object_fields(TaskAction, actionData)
|
||||
|
||||
# Create action in database
|
||||
createdAction = self.chatInterface.db.recordCreate("taskActions", actionData)
|
||||
createdAction = self.chatInterface.db.recordCreate(TaskAction, simple_fields)
|
||||
|
||||
# Convert to TaskAction model
|
||||
return TaskAction(
|
||||
|
|
@ -1095,27 +1098,36 @@ class HandlingTasks:
|
|||
)
|
||||
result_label = action.execResultLabel
|
||||
|
||||
# Process documents from the action result
|
||||
created_documents = []
|
||||
if result.success:
|
||||
created_documents = self.documentGenerator.createDocumentsFromActionResult(result, action, workflow)
|
||||
action.setSuccess()
|
||||
# Extract result text from documents if available, otherwise use empty string
|
||||
action.result = ""
|
||||
if result.documents and len(result.documents) > 0:
|
||||
# Try to get text content from the first document
|
||||
first_doc = result.documents[0]
|
||||
if isinstance(first_doc.documentData, dict):
|
||||
action.result = first_doc.documentData.get("result", "")
|
||||
elif isinstance(first_doc.documentData, str):
|
||||
action.result = first_doc.documentData
|
||||
# Preserve the action's execResultLabel for document routing
|
||||
# Action methods should NOT return resultLabel - this is managed by the action handler
|
||||
if not action.execResultLabel:
|
||||
logger.warning(f"Action {action.execMethod}.{action.execAction} has no execResultLabel set")
|
||||
# Always use the action's execResultLabel for message creation to ensure proper document routing
|
||||
message_result_label = action.execResultLabel
|
||||
await self.createActionMessage(action, result, workflow, message_result_label, created_documents, task_step, task_index)
|
||||
# Process documents from the action result
|
||||
created_documents = []
|
||||
if result.success:
|
||||
action.setSuccess()
|
||||
# Extract result text from documents if available, otherwise use empty string
|
||||
action.result = ""
|
||||
if result.documents and len(result.documents) > 0:
|
||||
# Try to get text content from the first document
|
||||
first_doc = result.documents[0]
|
||||
if isinstance(first_doc.documentData, dict):
|
||||
action.result = first_doc.documentData.get("result", "")
|
||||
elif isinstance(first_doc.documentData, str):
|
||||
action.result = first_doc.documentData
|
||||
# Preserve the action's execResultLabel for document routing
|
||||
# Action methods should NOT return resultLabel - this is managed by the action handler
|
||||
if not action.execResultLabel:
|
||||
logger.warning(f"Action {action.execMethod}.{action.execAction} has no execResultLabel set")
|
||||
# Always use the action's execResultLabel for message creation to ensure proper document routing
|
||||
message_result_label = action.execResultLabel
|
||||
|
||||
# Create message first to get messageId, then create documents with messageId
|
||||
message = await self.createActionMessage(action, result, workflow, message_result_label, [], task_step, task_index)
|
||||
if message:
|
||||
# Now create documents with the messageId
|
||||
created_documents = self.documentGenerator.createDocumentsFromActionResult(result, action, workflow, message.id)
|
||||
# Update the message with the created documents
|
||||
if created_documents:
|
||||
message.documents = created_documents
|
||||
# Update the message in the database
|
||||
self.chatInterface.updateMessage(message.id, {"documents": [doc.to_dict() for doc in created_documents]})
|
||||
|
||||
# Log action results
|
||||
logger.info(f"Action completed successfully")
|
||||
|
|
@ -1138,10 +1150,10 @@ class HandlingTasks:
|
|||
logger.error(f"Action failed: {result.error}")
|
||||
|
||||
# ⚠️ IMPORTANT: Create error message for failed actions so user can see what went wrong
|
||||
await self.createActionMessage(action, result, workflow, result_label, [], task_step, task_index)
|
||||
message = await self.createActionMessage(action, result, workflow, result_label, [], task_step, task_index)
|
||||
|
||||
# Create database log entry for action failure
|
||||
self.chatInterface.createWorkflowLog({
|
||||
self.chatInterface.createLog({
|
||||
"workflowId": workflow.id,
|
||||
"message": f"❌ **Task {task_num}**\n\n❌ **Action {action_num}/{total_actions}** failed: {result.error}",
|
||||
"type": "error"
|
||||
|
|
@ -1237,14 +1249,17 @@ class HandlingTasks:
|
|||
logger.info(f"Creating ERROR message: {message_text}")
|
||||
logger.info(f"Message data: {message_data}")
|
||||
|
||||
message = self.chatInterface.createWorkflowMessage(message_data)
|
||||
message = self.chatInterface.createMessage(message_data)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
logger.info(f"Message created: {action.execMethod}.{action.execAction}")
|
||||
return message
|
||||
else:
|
||||
logger.error(f"Failed to create workflow message for action {action.execMethod}.{action.execAction}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating action message: {str(e)}")
|
||||
return None
|
||||
|
||||
# --- Helper validation methods ---
|
||||
|
||||
|
|
|
|||
|
|
@ -920,7 +920,7 @@ Please provide a comprehensive summary of this conversation."""
|
|||
logger.error(f"Error during document access recovery for {document.id}: {str(e)}")
|
||||
return False
|
||||
|
||||
def createDocument(self, fileName: str, mimeType: str, content: str, base64encoded: bool = True) -> ChatDocument:
|
||||
def createDocument(self, fileName: str, mimeType: str, content: str, base64encoded: bool = True, messageId: str = None) -> ChatDocument:
|
||||
"""Create document with file in one step - handles file creation internally"""
|
||||
# Convert content to bytes based on base64 flag
|
||||
if base64encoded:
|
||||
|
|
@ -948,6 +948,7 @@ Please provide a comprehensive summary of this conversation."""
|
|||
# Create document with all file attributes copied
|
||||
document = ChatDocument(
|
||||
id=str(uuid.uuid4()),
|
||||
messageId=messageId or "", # Use provided messageId or empty string as fallback
|
||||
fileId=file_item.id,
|
||||
fileName=file_info.get("fileName", fileName),
|
||||
fileSize=file_info.get("size", 0),
|
||||
|
|
@ -1060,7 +1061,7 @@ Please provide a comprehensive summary of this conversation."""
|
|||
logger.error(f"Error executing method {methodName}.{actionName}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def processFileIds(self, fileIds: List[str]) -> List[ChatDocument]:
|
||||
async def processFileIds(self, fileIds: List[str], messageId: str = None) -> List[ChatDocument]:
|
||||
"""Process file IDs from existing files and return ChatDocument objects"""
|
||||
documents = []
|
||||
for fileId in fileIds:
|
||||
|
|
@ -1071,6 +1072,7 @@ Please provide a comprehensive summary of this conversation."""
|
|||
# Create document directly with all file attributes
|
||||
document = ChatDocument(
|
||||
id=str(uuid.uuid4()),
|
||||
messageId=messageId or "", # Use provided messageId or empty string as fallback
|
||||
fileId=fileId,
|
||||
fileName=fileInfo.get("fileName", "unknown"),
|
||||
fileSize=fileInfo.get("size", 0),
|
||||
|
|
|
|||
|
|
@ -33,9 +33,11 @@ class DatabaseConnector:
|
|||
# Set userId (default to empty string if None)
|
||||
self.userId = userId if userId is not None else ""
|
||||
|
||||
# Ensure the database directory exists
|
||||
# Initialize database system
|
||||
self.initDbSystem()
|
||||
|
||||
# Set up database folder path
|
||||
self.dbFolder = os.path.join(self.dbHost, self.dbDatabase)
|
||||
os.makedirs(self.dbFolder, exist_ok=True)
|
||||
|
||||
# Cache for loaded data
|
||||
self._tablesCache: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
|
@ -52,6 +54,17 @@ class DatabaseConnector:
|
|||
|
||||
logger.debug(f"Context: userId={self.userId}")
|
||||
|
||||
def initDbSystem(self):
|
||||
"""Initialize the database system - creates necessary directories and structure."""
|
||||
try:
|
||||
# Ensure the database directory exists
|
||||
self.dbFolder = os.path.join(self.dbHost, self.dbDatabase)
|
||||
os.makedirs(self.dbFolder, exist_ok=True)
|
||||
logger.info(f"Database system initialized: {self.dbFolder}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing database system: {e}")
|
||||
raise
|
||||
|
||||
def _initializeSystemTable(self):
|
||||
"""Initializes the system table if it doesn't exist yet."""
|
||||
systemTablePath = self._getTablePath(self._systemTableName)
|
||||
|
|
@ -652,8 +665,14 @@ class DatabaseConnector:
|
|||
except Exception as release_error:
|
||||
logger.error(f"Error releasing record lock for {recordPath}: {release_error}")
|
||||
|
||||
def getInitialId(self, table: str) -> Optional[str]:
|
||||
def getInitialId(self, table_or_model) -> Optional[str]:
|
||||
"""Returns the initial ID for a table."""
|
||||
# Handle both string table names (legacy) and model classes (new)
|
||||
if isinstance(table_or_model, str):
|
||||
table = table_or_model
|
||||
else:
|
||||
table = table_or_model.__name__
|
||||
|
||||
systemData = self._loadSystemTable()
|
||||
initialId = systemData.get(table)
|
||||
logger.debug(f"Initial ID for table '{table}': {initialId}")
|
||||
|
|
|
|||
840
modules/connectors/connectorDbPostgre.py
Normal file
840
modules/connectors/connectorDbPostgre.py
Normal file
|
|
@ -0,0 +1,840 @@
|
|||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Union, get_origin, get_args
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
import threading
|
||||
import time
|
||||
|
||||
from modules.shared.attributeUtils import to_dict
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.interfaceAppModel import SystemTable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# No mapping needed - table name = Pydantic model name exactly
|
||||
|
||||
def _get_model_fields(model_class) -> Dict[str, str]:
|
||||
"""Get all fields from Pydantic model and map to SQL types."""
|
||||
if not hasattr(model_class, '__fields__'):
|
||||
return {}
|
||||
|
||||
fields = {}
|
||||
for field_name, field_info in model_class.__fields__.items():
|
||||
field_type = field_info.type_
|
||||
|
||||
# Check for JSONB fields (Dict, List, or complex types)
|
||||
if (field_type == dict or
|
||||
field_type == list or
|
||||
(hasattr(field_type, '__origin__') and field_type.__origin__ in (dict, list)) or
|
||||
field_name in ['execParameters', 'expectedDocumentFormats', 'resultDocuments', 'logs', 'messages', 'stats', 'tasks']):
|
||||
fields[field_name] = 'JSONB'
|
||||
# Simple type mapping
|
||||
elif field_type in (str, type(None)) or (get_origin(field_type) is Union and type(None) in get_args(field_type)):
|
||||
fields[field_name] = 'TEXT'
|
||||
elif field_type == int:
|
||||
fields[field_name] = 'INTEGER'
|
||||
elif field_type == float:
|
||||
fields[field_name] = 'REAL'
|
||||
elif field_type == bool:
|
||||
fields[field_name] = 'BOOLEAN'
|
||||
else:
|
||||
fields[field_name] = 'TEXT' # Default to TEXT
|
||||
|
||||
return fields
|
||||
|
||||
# No caching needed with proper database
|
||||
|
||||
class DatabaseConnector:
|
||||
"""
|
||||
A connector for PostgreSQL-based data storage.
|
||||
Provides generic database operations without user/mandate filtering.
|
||||
Uses PostgreSQL with JSONB columns for flexible data storage.
|
||||
"""
|
||||
def __init__(self, dbHost: str, dbDatabase: str, dbUser: str = None, dbPassword: str = None, dbPort: int = None, userId: str = None):
|
||||
# Store the input parameters
|
||||
self.dbHost = dbHost
|
||||
self.dbDatabase = dbDatabase
|
||||
self.dbUser = dbUser
|
||||
self.dbPassword = dbPassword
|
||||
self.dbPort = dbPort
|
||||
|
||||
# Set userId (default to empty string if None)
|
||||
self.userId = userId if userId is not None else ""
|
||||
|
||||
# Initialize database system first (creates database if needed)
|
||||
self.connection = None
|
||||
self.initDbSystem()
|
||||
|
||||
# No caching needed with proper database - PostgreSQL handles performance
|
||||
|
||||
# Thread safety
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Initialize system table
|
||||
self._systemTableName = "_system"
|
||||
self._initializeSystemTable()
|
||||
|
||||
logger.debug(f"Context: userId={self.userId}")
|
||||
|
||||
def initDbSystem(self):
|
||||
"""Initialize the database system - creates database and tables."""
|
||||
try:
|
||||
# Create database if it doesn't exist
|
||||
self._create_database_if_not_exists()
|
||||
|
||||
# Create tables
|
||||
self._create_tables()
|
||||
|
||||
# Establish connection to the database
|
||||
self._connect()
|
||||
|
||||
logger.info("PostgreSQL database system initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
|
||||
raise
|
||||
|
||||
def _create_database_if_not_exists(self):
|
||||
"""Create the database if it doesn't exist."""
|
||||
try:
|
||||
# Use the configured user for database creation
|
||||
conn = psycopg2.connect(
|
||||
host=self.dbHost,
|
||||
port=self.dbPort,
|
||||
database="postgres",
|
||||
user=self.dbUser,
|
||||
password=self.dbPassword,
|
||||
client_encoding='utf8'
|
||||
)
|
||||
conn.autocommit = True
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Check if database exists
|
||||
cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,))
|
||||
exists = cursor.fetchone()
|
||||
|
||||
if not exists:
|
||||
# Create database
|
||||
cursor.execute(f"CREATE DATABASE {self.dbDatabase}")
|
||||
logger.info(f"Created database: {self.dbDatabase}")
|
||||
else:
|
||||
logger.info(f"Database {self.dbDatabase} already exists")
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"FATAL ERROR: Cannot create database: {e}")
|
||||
logger.error("Database connection failed - application cannot start")
|
||||
raise RuntimeError(f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}")
|
||||
|
||||
|
||||
def _create_tables(self):
|
||||
"""Create only the system table - application tables are created by interfaces."""
|
||||
try:
|
||||
# Use the configured user for table creation
|
||||
conn = psycopg2.connect(
|
||||
host=self.dbHost,
|
||||
port=self.dbPort,
|
||||
database=self.dbDatabase,
|
||||
user=self.dbUser,
|
||||
password=self.dbPassword,
|
||||
client_encoding='utf8'
|
||||
)
|
||||
conn.autocommit = True
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Create only the system table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS _system (
|
||||
id SERIAL PRIMARY KEY,
|
||||
table_name VARCHAR(255) UNIQUE NOT NULL,
|
||||
initial_id VARCHAR(255) NOT NULL,
|
||||
_createdAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
_modifiedAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
logger.info("System table created successfully")
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"FATAL ERROR: Cannot create system table: {e}")
|
||||
logger.error("Database system table creation failed - application cannot start")
|
||||
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
|
||||
|
||||
def _connect(self):
|
||||
"""Establish connection to PostgreSQL database."""
|
||||
try:
|
||||
# Use configured user for main connection with proper parameter handling
|
||||
self.connection = psycopg2.connect(
|
||||
host=self.dbHost,
|
||||
port=self.dbPort,
|
||||
database=self.dbDatabase,
|
||||
user=self.dbUser,
|
||||
password=self.dbPassword,
|
||||
client_encoding='utf8',
|
||||
cursor_factory=psycopg2.extras.RealDictCursor
|
||||
)
|
||||
self.connection.autocommit = False # Use transactions
|
||||
logger.info(f"Connected to PostgreSQL database: {self.dbDatabase}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to PostgreSQL: {e}")
|
||||
raise
|
||||
|
||||
def _ensure_connection(self):
|
||||
"""Ensure database connection is alive, reconnect if necessary."""
|
||||
try:
|
||||
if self.connection is None or self.connection.closed:
|
||||
self._connect()
|
||||
else:
|
||||
# Test connection with a simple query
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("SELECT 1")
|
||||
except Exception as e:
|
||||
logger.warning(f"Connection lost, reconnecting: {e}")
|
||||
self._connect()
|
||||
|
||||
def _initializeSystemTable(self):
|
||||
"""Initializes the system table if it doesn't exist yet."""
|
||||
try:
|
||||
# First ensure the system table exists
|
||||
self._ensureTableExists(SystemTable)
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
# Check if system table has any data
|
||||
cursor.execute('SELECT COUNT(*) FROM "_system"')
|
||||
row = cursor.fetchone()
|
||||
count = row['count'] if row else 0
|
||||
|
||||
self.connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing system table: {e}")
|
||||
self.connection.rollback()
|
||||
raise
|
||||
|
||||
def _loadSystemTable(self) -> Dict[str, str]:
|
||||
"""Loads the system table with the initial IDs."""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
|
||||
rows = cursor.fetchall()
|
||||
|
||||
system_data = {}
|
||||
for row in rows:
|
||||
system_data[row['table_name']] = row['initial_id']
|
||||
|
||||
return system_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading system table: {e}")
|
||||
return {}
|
||||
|
||||
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
|
||||
"""Saves the system table with the initial IDs."""
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
# Clear existing data
|
||||
cursor.execute('DELETE FROM "_system"')
|
||||
|
||||
# Insert new data
|
||||
for table_name, initial_id in data.items():
|
||||
cursor.execute("""
|
||||
INSERT INTO "_system" ("table_name", "initial_id", "_modifiedAt")
|
||||
VALUES (%s, %s, CURRENT_TIMESTAMP)
|
||||
""", (table_name, initial_id))
|
||||
|
||||
self.connection.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving system table: {e}")
|
||||
self.connection.rollback()
|
||||
return False
|
||||
|
||||
def _ensureSystemTableExists(self) -> bool:
|
||||
"""Ensures the system table exists, creates it if it doesn't."""
|
||||
try:
|
||||
self._ensure_connection()
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
# Check if system table exists
|
||||
cursor.execute("SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", (self._systemTableName,))
|
||||
exists = cursor.fetchone()['count'] > 0
|
||||
|
||||
if not exists:
|
||||
# Create system table
|
||||
cursor.execute(f"""
|
||||
CREATE TABLE "{self._systemTableName}" (
|
||||
"table_name" VARCHAR(255) PRIMARY KEY,
|
||||
"initial_id" VARCHAR(255),
|
||||
"_createdAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
"_modifiedAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
logger.info("System table created successfully")
|
||||
else:
|
||||
# Check if we need to add missing columns to existing table
|
||||
cursor.execute("""
|
||||
SELECT column_name FROM information_schema.columns
|
||||
WHERE table_name = %s AND table_schema = 'public'
|
||||
""", (self._systemTableName,))
|
||||
existing_columns = [row['column_name'] for row in cursor.fetchall()]
|
||||
|
||||
if '_modifiedAt' not in existing_columns:
|
||||
cursor.execute(f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "_modifiedAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP')
|
||||
logger.info("Added _modifiedAt column to existing system table")
|
||||
|
||||
logger.debug("System table already exists")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring system table exists: {e}")
|
||||
return False
|
||||
|
||||
def _ensureTableExists(self, model_class: type) -> bool:
|
||||
"""Ensures a table exists, creates it if it doesn't."""
|
||||
table = model_class.__name__
|
||||
|
||||
if table == "SystemTable":
|
||||
# Handle system table specially - it uses _system as the actual table name
|
||||
return self._ensureSystemTableExists()
|
||||
|
||||
try:
|
||||
self._ensure_connection()
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
# Check if table exists by querying information_schema with case-insensitive search
|
||||
cursor.execute('''
|
||||
SELECT COUNT(*) FROM information_schema.tables
|
||||
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
|
||||
''', (table,))
|
||||
exists = cursor.fetchone()['count'] > 0
|
||||
logger.debug(f"Table {table} exists check: {exists}")
|
||||
|
||||
if not exists:
|
||||
# Create table from Pydantic model
|
||||
logger.debug(f"Creating table {table} with model {model_class}")
|
||||
self._create_table_from_model(cursor, table, model_class)
|
||||
logger.info(f"Created table '{table}' with columns from Pydantic model")
|
||||
|
||||
self.connection.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring table {table} exists: {e}")
|
||||
if hasattr(self, 'connection') and self.connection:
|
||||
self.connection.rollback()
|
||||
return False
|
||||
|
||||
|
||||
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
|
||||
"""Create table with columns matching Pydantic model fields."""
|
||||
fields = _get_model_fields(model_class)
|
||||
logger.debug(f"Creating table {table} with fields: {fields}")
|
||||
|
||||
# Build column definitions with quoted identifiers to preserve exact case
|
||||
columns = ['"id" VARCHAR(255) PRIMARY KEY']
|
||||
for field_name, sql_type in fields.items():
|
||||
if field_name != 'id': # Skip id, already defined
|
||||
columns.append(f'"{field_name}" {sql_type}')
|
||||
|
||||
# Add metadata columns
|
||||
columns.extend([
|
||||
'"_createdAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP',
|
||||
'"_modifiedAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP',
|
||||
'"_createdBy" VARCHAR(255)',
|
||||
'"_modifiedBy" VARCHAR(255)'
|
||||
])
|
||||
|
||||
# Create table
|
||||
sql = f'CREATE TABLE IF NOT EXISTS "{table}" ({", ".join(columns)})'
|
||||
logger.debug(f"Executing SQL: {sql}")
|
||||
cursor.execute(sql)
|
||||
|
||||
# Create indexes for foreign keys
|
||||
for field_name in fields:
|
||||
if field_name.endswith('Id') and field_name != 'id':
|
||||
cursor.execute(f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")')
|
||||
|
||||
|
||||
def _save_record(self, cursor, table: str, recordId: str, record: Dict[str, Any], model_class: type) -> None:
|
||||
"""Save record to normalized table with explicit columns."""
|
||||
# Get columns from Pydantic model instead of database schema
|
||||
fields = _get_model_fields(model_class)
|
||||
columns = ['id'] + [field for field in fields.keys() if field != 'id'] + ['_createdAt', '_createdBy', '_modifiedAt', '_modifiedBy']
|
||||
|
||||
logger.debug(f"Table {table} columns: {columns}")
|
||||
logger.debug(f"Record data: {record}")
|
||||
|
||||
if not columns:
|
||||
logger.error(f"No columns found for table {table}")
|
||||
return
|
||||
|
||||
# Filter record data to only include columns that exist in the table
|
||||
filtered_record = {k: v for k, v in record.items() if k in columns}
|
||||
|
||||
# Ensure id is set
|
||||
filtered_record['id'] = recordId
|
||||
|
||||
# Prepare values in the correct order
|
||||
values = []
|
||||
for col in columns:
|
||||
value = filtered_record.get(col)
|
||||
|
||||
# Convert timestamp fields to proper PostgreSQL format
|
||||
if col in ['_createdAt', '_modifiedAt'] and value is not None:
|
||||
if isinstance(value, (int, float)):
|
||||
# Convert Unix timestamp to PostgreSQL timestamp
|
||||
from datetime import datetime
|
||||
value = datetime.fromtimestamp(value)
|
||||
elif isinstance(value, str):
|
||||
# If it's already a string, try to parse it
|
||||
try:
|
||||
from datetime import datetime
|
||||
value = datetime.fromtimestamp(float(value))
|
||||
except:
|
||||
pass # Keep as string if parsing fails
|
||||
|
||||
# Convert enum values to their string representation
|
||||
elif hasattr(value, 'value'):
|
||||
value = value.value
|
||||
|
||||
# Handle JSONB fields - ensure proper JSON format for PostgreSQL
|
||||
elif col in fields and fields[col] == 'JSONB' and value is not None:
|
||||
import json
|
||||
if isinstance(value, (dict, list)):
|
||||
# Convert Python objects to JSON string for PostgreSQL JSONB
|
||||
value = json.dumps(value)
|
||||
elif isinstance(value, str):
|
||||
# Validate that it's valid JSON, if not, try to parse and re-serialize
|
||||
try:
|
||||
# Test if it's already valid JSON
|
||||
json.loads(value)
|
||||
# If successful, keep as is
|
||||
pass
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If not valid JSON, convert to JSON string
|
||||
value = json.dumps(value)
|
||||
else:
|
||||
# Convert other types to JSON
|
||||
value = json.dumps(value)
|
||||
|
||||
values.append(value)
|
||||
|
||||
logger.debug(f"Values to insert: {values}")
|
||||
|
||||
# Build INSERT/UPDATE with quoted identifiers
|
||||
col_names = ', '.join([f'"{col}"' for col in columns])
|
||||
placeholders = ', '.join(['%s'] * len(columns))
|
||||
updates = ', '.join([f'"{col}" = EXCLUDED."{col}"' for col in columns[1:] if col not in ['_createdAt', '_createdBy']])
|
||||
|
||||
sql = f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders}) ON CONFLICT ("id") DO UPDATE SET {updates}'
|
||||
logger.debug(f"SQL: {sql}")
|
||||
|
||||
cursor.execute(sql, values)
|
||||
|
||||
def _loadRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Loads a single record from the normalized table."""
|
||||
table = model_class.__name__
|
||||
|
||||
try:
|
||||
if not self._ensureTableExists(model_class):
|
||||
return None
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
# Convert row to dict and handle JSONB fields
|
||||
record = dict(row)
|
||||
fields = _get_model_fields(model_class)
|
||||
|
||||
# Parse JSONB fields back to Python objects
|
||||
for field_name, field_type in fields.items():
|
||||
if field_type == 'JSONB' and field_name in record and record[field_name] is not None:
|
||||
import json
|
||||
try:
|
||||
if isinstance(record[field_name], str):
|
||||
# Parse JSON string back to Python object
|
||||
record[field_name] = json.loads(record[field_name])
|
||||
elif isinstance(record[field_name], (dict, list)):
|
||||
# Already a Python object, keep as is
|
||||
pass
|
||||
else:
|
||||
# Try to parse as JSON
|
||||
record[field_name] = json.loads(str(record[field_name]))
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
# If parsing fails, keep as string
|
||||
logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}")
|
||||
pass
|
||||
|
||||
return record
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading record {recordId} from table {table}: {e}")
|
||||
return None
|
||||
|
||||
def _saveRecord(self, model_class: type, recordId: str, record: Dict[str, Any]) -> bool:
|
||||
"""Saves a single record to the table."""
|
||||
table = model_class.__name__
|
||||
|
||||
try:
|
||||
if not self._ensureTableExists(model_class):
|
||||
return False
|
||||
|
||||
recordId = str(recordId)
|
||||
if "id" in record and str(record["id"]) != recordId:
|
||||
raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}")
|
||||
|
||||
# Add metadata
|
||||
currentTime = get_utc_timestamp()
|
||||
if "_createdAt" not in record:
|
||||
record["_createdAt"] = currentTime
|
||||
record["_createdBy"] = self.userId
|
||||
record["_modifiedAt"] = currentTime
|
||||
record["_modifiedBy"] = self.userId
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
self._save_record(cursor, table, recordId, record, model_class)
|
||||
|
||||
self.connection.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving record {recordId} to table {table}: {e}")
|
||||
self.connection.rollback()
|
||||
return False
|
||||
|
||||
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
|
||||
"""Loads all records from a normalized table."""
|
||||
table = model_class.__name__
|
||||
|
||||
if table == self._systemTableName:
|
||||
return self._loadSystemTable()
|
||||
|
||||
try:
|
||||
if not self._ensureTableExists(model_class):
|
||||
return []
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
|
||||
records = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
# Handle JSONB fields for all records
|
||||
fields = _get_model_fields(model_class)
|
||||
for record in records:
|
||||
for field_name, field_type in fields.items():
|
||||
if field_type == 'JSONB' and field_name in record and record[field_name] is not None:
|
||||
import json
|
||||
try:
|
||||
if isinstance(record[field_name], str):
|
||||
# Parse JSON string back to Python object
|
||||
record[field_name] = json.loads(record[field_name])
|
||||
elif isinstance(record[field_name], (dict, list)):
|
||||
# Already a Python object, keep as is
|
||||
pass
|
||||
else:
|
||||
# Try to parse as JSON
|
||||
record[field_name] = json.loads(str(record[field_name]))
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
# If parsing fails, keep as string
|
||||
logger.warning(f"Could not parse JSONB field {field_name}, keeping as string: {record[field_name]}")
|
||||
pass
|
||||
|
||||
return records
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading table {table}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _applyRecordFilter(self, records: List[Dict[str, Any]], recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""Applies a record filter to the records"""
|
||||
if not recordFilter:
|
||||
return records
|
||||
|
||||
filteredRecords = []
|
||||
|
||||
for record in records:
|
||||
match = True
|
||||
|
||||
for field, value in recordFilter.items():
|
||||
# Check if the field exists
|
||||
if field not in record:
|
||||
match = False
|
||||
break
|
||||
|
||||
# Convert both values to strings for comparison
|
||||
recordValue = str(record[field])
|
||||
filterValue = str(value)
|
||||
|
||||
# Direct string comparison
|
||||
if recordValue != filterValue:
|
||||
match = False
|
||||
break
|
||||
|
||||
if match:
|
||||
filteredRecords.append(record)
|
||||
|
||||
return filteredRecords
|
||||
|
||||
def _registerInitialId(self, table: str, initialId: str) -> bool:
|
||||
"""Registers the initial ID for a table."""
|
||||
try:
|
||||
systemData = self._loadSystemTable()
|
||||
|
||||
if table not in systemData:
|
||||
systemData[table] = initialId
|
||||
success = self._saveSystemTable(systemData)
|
||||
if success:
|
||||
logger.info(f"Initial ID {initialId} for table {table} registered")
|
||||
return success
|
||||
else:
|
||||
# Check if the existing initial ID still exists in the table
|
||||
existingInitialId = systemData[table]
|
||||
records = self.getRecordset(model_class, recordFilter={"id": existingInitialId})
|
||||
if not records:
|
||||
# The initial record no longer exists, update to the new one
|
||||
systemData[table] = initialId
|
||||
success = self._saveSystemTable(systemData)
|
||||
if success:
|
||||
logger.info(f"Initial ID updated from {existingInitialId} to {initialId} for table {table}")
|
||||
return success
|
||||
else:
|
||||
logger.debug(f"Initial ID {existingInitialId} for table {table} already exists and is valid")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering the initial ID for table {table}: {e}")
|
||||
return False
|
||||
|
||||
def _removeInitialId(self, table: str) -> bool:
|
||||
"""Removes the initial ID for a table from the system table."""
|
||||
try:
|
||||
systemData = self._loadSystemTable()
|
||||
|
||||
if table in systemData:
|
||||
del systemData[table]
|
||||
success = self._saveSystemTable(systemData)
|
||||
if success:
|
||||
logger.info(f"Initial ID for table {table} removed from system table")
|
||||
return success
|
||||
return True # If not present, this is not an error
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing initial ID for table {table}: {e}")
|
||||
return False
|
||||
|
||||
def updateContext(self, userId: str) -> None:
|
||||
"""Updates the context of the database connector."""
|
||||
if userId is None:
|
||||
raise ValueError("userId must be provided")
|
||||
|
||||
self.userId = userId
|
||||
logger.info(f"Updated database context: userId={self.userId}")
|
||||
|
||||
# No cache to clear - database handles data consistency
|
||||
|
||||
def clearTableCache(self, model_class: type) -> None:
|
||||
"""No-op: Database handles data consistency automatically."""
|
||||
# No caching with proper database - PostgreSQL handles consistency
|
||||
pass
|
||||
|
||||
# Public API
|
||||
|
||||
def getTables(self) -> List[str]:
|
||||
"""Returns a list of all available tables."""
|
||||
tables = []
|
||||
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name NOT LIKE '_%'
|
||||
ORDER BY table_name
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
tables = [row['table_name'] for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading the database: {e}")
|
||||
|
||||
return tables
|
||||
|
||||
def getFields(self, model_class: type) -> List[str]:
|
||||
"""Returns a list of all fields in a table."""
|
||||
data = self._loadTable(model_class)
|
||||
|
||||
if not data:
|
||||
return []
|
||||
|
||||
fields = list(data[0].keys()) if data else []
|
||||
|
||||
return fields
|
||||
|
||||
def getSchema(self, model_class: type, language: str = None) -> Dict[str, Dict[str, Any]]:
|
||||
"""Returns a schema object for a table with data types and labels."""
|
||||
data = self._loadTable(model_class)
|
||||
|
||||
schema = {}
|
||||
|
||||
if not data:
|
||||
return schema
|
||||
|
||||
firstRecord = data[0]
|
||||
|
||||
for field, value in firstRecord.items():
|
||||
dataType = type(value).__name__
|
||||
label = field
|
||||
|
||||
schema[field] = {
|
||||
"type": dataType,
|
||||
"label": label
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
def getRecordset(self, model_class: type, fieldFilter: List[str] = None, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""Returns a list of records from a table, filtered by criteria."""
|
||||
table = model_class.__name__
|
||||
|
||||
# If we have specific record IDs in the filter, only load those records
|
||||
if recordFilter and "id" in recordFilter:
|
||||
recordId = recordFilter["id"]
|
||||
record = self._loadRecord(model_class, recordId)
|
||||
if record:
|
||||
records = [record]
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
# Load all records if no specific ID filter
|
||||
records = self._loadTable(model_class)
|
||||
|
||||
# Apply recordFilter if available
|
||||
if recordFilter:
|
||||
records = self._applyRecordFilter(records, recordFilter)
|
||||
|
||||
# If fieldFilter is available, reduce the fields
|
||||
if fieldFilter and isinstance(fieldFilter, list):
|
||||
result = []
|
||||
for record in records:
|
||||
filteredRecord = {}
|
||||
for field in fieldFilter:
|
||||
if field in record:
|
||||
filteredRecord[field] = record[field]
|
||||
result.append(filteredRecord)
|
||||
return result
|
||||
|
||||
return records
|
||||
|
||||
def recordCreate(self, model_class: type, record: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
|
||||
"""Creates a new record in a table based on Pydantic model class."""
|
||||
# If record is a Pydantic model, convert to dict
|
||||
if isinstance(record, BaseModel):
|
||||
record = to_dict(record)
|
||||
elif isinstance(record, dict):
|
||||
record = record.copy()
|
||||
else:
|
||||
raise ValueError("Record must be a Pydantic model or dictionary")
|
||||
|
||||
# Ensure record has an ID
|
||||
if "id" not in record:
|
||||
record["id"] = str(uuid.uuid4())
|
||||
|
||||
# Save record
|
||||
self._saveRecord(model_class, record["id"], record)
|
||||
|
||||
# Check if this is the first record in the table and register as initial ID
|
||||
table = model_class.__name__
|
||||
existingInitialId = self.getInitialId(model_class)
|
||||
if existingInitialId is None:
|
||||
# This is the first record, register it as the initial ID
|
||||
self._registerInitialId(table, record["id"])
|
||||
logger.info(f"Registered initial ID {record['id']} for table {table}")
|
||||
|
||||
return record
|
||||
|
||||
def recordModify(self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
|
||||
"""Modifies an existing record in a table based on Pydantic model class."""
|
||||
# Load existing record
|
||||
existingRecord = self._loadRecord(model_class, recordId)
|
||||
if not existingRecord:
|
||||
table = model_class.__name__
|
||||
raise ValueError(f"Record {recordId} not found in table {table}")
|
||||
|
||||
# If record is a Pydantic model, convert to dict
|
||||
if isinstance(record, BaseModel):
|
||||
record = to_dict(record)
|
||||
elif isinstance(record, dict):
|
||||
record = record.copy()
|
||||
else:
|
||||
raise ValueError("Record must be a Pydantic model or dictionary")
|
||||
|
||||
# CRITICAL: Ensure we never modify the ID
|
||||
if "id" in record and str(record["id"]) != recordId:
|
||||
logger.error(f"Attempted to modify record ID from {recordId} to {record['id']}")
|
||||
raise ValueError("Cannot modify record ID - it must match the provided recordId")
|
||||
|
||||
# Update existing record with new data
|
||||
existingRecord.update(record)
|
||||
|
||||
# Save updated record
|
||||
self._saveRecord(model_class, recordId, existingRecord)
|
||||
return existingRecord
|
||||
|
||||
def recordDelete(self, model_class: type, recordId: str) -> bool:
|
||||
"""Deletes a record from the table based on Pydantic model class."""
|
||||
table = model_class.__name__
|
||||
|
||||
try:
|
||||
if not self._ensureTableExists(model_class):
|
||||
return False
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
# Check if record exists
|
||||
cursor.execute(f"SELECT id FROM {table} WHERE id = %s", (recordId,))
|
||||
if not cursor.fetchone():
|
||||
return False
|
||||
|
||||
# Check if it's an initial record
|
||||
initialId = self.getInitialId(model_class)
|
||||
if initialId is not None and initialId == recordId:
|
||||
self._removeInitialId(table)
|
||||
logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table")
|
||||
|
||||
# Delete the record
|
||||
cursor.execute(f"DELETE FROM {table} WHERE id = %s", (recordId,))
|
||||
|
||||
# No cache to update - database handles consistency
|
||||
|
||||
self.connection.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
|
||||
self.connection.rollback()
|
||||
return False
|
||||
|
||||
|
||||
def getInitialId(self, model_class: type) -> Optional[str]:
|
||||
"""Returns the initial ID for a table."""
|
||||
table = model_class.__name__
|
||||
systemData = self._loadSystemTable()
|
||||
initialId = systemData.get(table)
|
||||
logger.debug(f"Initial ID for table '{table}': {initialId}")
|
||||
return initialId
|
||||
|
||||
def close(self):
|
||||
"""Close the database connection."""
|
||||
if hasattr(self, 'connection') and self.connection and not self.connection.closed:
|
||||
self.connection.close()
|
||||
logger.debug("Database connection closed")
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup method to close connection."""
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
# Ignore errors during cleanup
|
||||
pass
|
||||
|
|
@ -1,178 +0,0 @@
|
|||
import threading
|
||||
import queue
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from .connectorDbJson import DatabaseConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DatabaseConnectorPool:
|
||||
"""
|
||||
A connection pool for DatabaseConnector instances to manage resources efficiently
|
||||
and ensure proper isolation between users.
|
||||
"""
|
||||
|
||||
def __init__(self, max_connections: int = 100, max_idle_time: int = 300):
|
||||
"""
|
||||
Initialize the connection pool.
|
||||
|
||||
Args:
|
||||
max_connections: Maximum number of connections in the pool
|
||||
max_idle_time: Maximum idle time in seconds before connection is considered stale
|
||||
"""
|
||||
self.max_connections = max_connections
|
||||
self.max_idle_time = max_idle_time
|
||||
self._pool = queue.Queue(maxsize=max_connections)
|
||||
self._created_connections = 0
|
||||
self._lock = threading.Lock()
|
||||
self._connection_times = {} # Track when connections were created
|
||||
|
||||
def _create_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None,
|
||||
dbPassword: str = None, userId: str = None) -> DatabaseConnector:
|
||||
"""Create a new DatabaseConnector instance."""
|
||||
with self._lock:
|
||||
if self._created_connections >= self.max_connections:
|
||||
raise RuntimeError(f"Maximum connections ({self.max_connections}) exceeded")
|
||||
|
||||
self._created_connections += 1
|
||||
logger.debug(f"Creating new database connector (total: {self._created_connections})")
|
||||
|
||||
connector = DatabaseConnector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
userId=userId
|
||||
)
|
||||
|
||||
# Track creation time
|
||||
connector_id = id(connector)
|
||||
self._connection_times[connector_id] = time.time()
|
||||
|
||||
return connector
|
||||
|
||||
def get_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None,
|
||||
dbPassword: str = None, userId: str = None) -> DatabaseConnector:
|
||||
"""
|
||||
Get a database connector from the pool or create a new one.
|
||||
|
||||
Args:
|
||||
dbHost: Database host path
|
||||
dbDatabase: Database name
|
||||
dbUser: Database user (optional)
|
||||
dbPassword: Database password (optional)
|
||||
userId: User ID for context (optional)
|
||||
|
||||
Returns:
|
||||
DatabaseConnector instance
|
||||
"""
|
||||
try:
|
||||
# Try to get an existing connector from the pool
|
||||
connector = self._pool.get_nowait()
|
||||
|
||||
# Check if connector is stale
|
||||
connector_id = id(connector)
|
||||
if connector_id in self._connection_times:
|
||||
idle_time = time.time() - self._connection_times[connector_id]
|
||||
if idle_time > self.max_idle_time:
|
||||
logger.debug(f"Connector {connector_id} is stale (idle: {idle_time}s), creating new one")
|
||||
# Remove stale connector from tracking
|
||||
if connector_id in self._connection_times:
|
||||
del self._connection_times[connector_id]
|
||||
# Create new connector
|
||||
return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)
|
||||
|
||||
# Update user context if provided
|
||||
if userId is not None:
|
||||
connector.updateContext(userId)
|
||||
|
||||
logger.debug(f"Reusing existing connector {connector_id}")
|
||||
return connector
|
||||
|
||||
except queue.Empty:
|
||||
# Pool is empty, create new connector
|
||||
return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)
|
||||
|
||||
def return_connector(self, connector: DatabaseConnector) -> None:
|
||||
"""
|
||||
Return a connector to the pool for reuse.
|
||||
|
||||
Args:
|
||||
connector: DatabaseConnector instance to return
|
||||
"""
|
||||
try:
|
||||
# Update connection time
|
||||
connector_id = id(connector)
|
||||
self._connection_times[connector_id] = time.time()
|
||||
|
||||
# Try to return to pool
|
||||
self._pool.put_nowait(connector)
|
||||
logger.debug(f"Returned connector {connector_id} to pool")
|
||||
|
||||
except queue.Full:
|
||||
# Pool is full, discard connector
|
||||
logger.debug(f"Pool full, discarding connector {id(connector)}")
|
||||
with self._lock:
|
||||
self._created_connections -= 1
|
||||
if id(connector) in self._connection_times:
|
||||
del self._connection_times[id(connector)]
|
||||
|
||||
def cleanup_stale_connections(self) -> int:
|
||||
"""
|
||||
Clean up stale connections from the pool.
|
||||
|
||||
Returns:
|
||||
Number of connections cleaned up
|
||||
"""
|
||||
cleaned = 0
|
||||
current_time = time.time()
|
||||
|
||||
# Check all tracked connections
|
||||
stale_connectors = []
|
||||
for connector_id, creation_time in list(self._connection_times.items()):
|
||||
if current_time - creation_time > self.max_idle_time:
|
||||
stale_connectors.append(connector_id)
|
||||
|
||||
# Remove stale connections from tracking
|
||||
for connector_id in stale_connectors:
|
||||
if connector_id in self._connection_times:
|
||||
del self._connection_times[connector_id]
|
||||
cleaned += 1
|
||||
|
||||
logger.debug(f"Cleaned up {cleaned} stale connections")
|
||||
return cleaned
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get pool statistics."""
|
||||
with self._lock:
|
||||
return {
|
||||
"max_connections": self.max_connections,
|
||||
"created_connections": self._created_connections,
|
||||
"available_connections": self._pool.qsize(),
|
||||
"tracked_connections": len(self._connection_times)
|
||||
}
|
||||
|
||||
# Global pool instance
|
||||
_connector_pool = None
|
||||
_pool_lock = threading.Lock()
|
||||
|
||||
def get_connector_pool() -> DatabaseConnectorPool:
|
||||
"""Get the global connector pool instance."""
|
||||
global _connector_pool
|
||||
if _connector_pool is None:
|
||||
with _pool_lock:
|
||||
if _connector_pool is None:
|
||||
_connector_pool = DatabaseConnectorPool()
|
||||
return _connector_pool
|
||||
|
||||
def get_connector(dbHost: str, dbDatabase: str, dbUser: str = None,
|
||||
dbPassword: str = None, userId: str = None) -> DatabaseConnector:
|
||||
"""Get a database connector from the global pool."""
|
||||
pool = get_connector_pool()
|
||||
return pool.get_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)
|
||||
|
||||
def return_connector(connector: DatabaseConnector) -> None:
|
||||
"""Return a database connector to the global pool."""
|
||||
pool = get_connector_pool()
|
||||
pool.return_connector(connector)
|
||||
|
|
@ -5,7 +5,7 @@ Access control for the Application.
|
|||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from modules.interfaces.interfaceAppModel import UserPrivilege, User
|
||||
from modules.interfaces.interfaceAppModel import UserPrivilege, User, UserInDB, AuthEvent
|
||||
from modules.shared.timezoneUtils import get_utc_now
|
||||
|
||||
# Configure logger
|
||||
|
|
@ -29,28 +29,29 @@ class AppAccess:
|
|||
|
||||
self.db = db
|
||||
|
||||
def uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Unified user access management function that filters data based on user privileges
|
||||
and adds access control attributes.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
model_class: Pydantic model class for the table
|
||||
recordset: Recordset to filter based on access rules
|
||||
|
||||
Returns:
|
||||
Filtered recordset with access control attributes
|
||||
"""
|
||||
filtered_records = []
|
||||
table_name = model_class.__name__
|
||||
|
||||
# Only SYSADMIN can see mandates
|
||||
if table == "mandates":
|
||||
if table_name == "Mandate":
|
||||
if self.privilege == UserPrivilege.SYSADMIN:
|
||||
filtered_records = recordset
|
||||
else:
|
||||
filtered_records = []
|
||||
# Special handling for users table
|
||||
elif table == "users":
|
||||
elif table_name == "UserInDB":
|
||||
if self.privilege == UserPrivilege.SYSADMIN:
|
||||
# SysAdmin sees all users
|
||||
filtered_records = recordset
|
||||
|
|
@ -61,13 +62,13 @@ class AppAccess:
|
|||
# Regular users only see themselves
|
||||
filtered_records = [r for r in recordset if r.get("id") == self.userId]
|
||||
# Special handling for connections table
|
||||
elif table == "connections":
|
||||
elif table_name == "UserConnection":
|
||||
if self.privilege == UserPrivilege.SYSADMIN:
|
||||
# SysAdmin sees all connections
|
||||
filtered_records = recordset
|
||||
elif self.privilege == UserPrivilege.ADMIN:
|
||||
# Admin sees connections for users in their mandate
|
||||
users: List[Dict[str, Any]] = self.db.getRecordset("users", recordFilter={"mandateId": self.mandateId})
|
||||
users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId})
|
||||
user_ids: List[str] = [str(u["id"]) for u in users]
|
||||
filtered_records = [r for r in recordset if r.get("userId") in user_ids]
|
||||
else:
|
||||
|
|
@ -89,11 +90,11 @@ class AppAccess:
|
|||
record_id = record.get("id")
|
||||
|
||||
# Set access control flags based on user permissions
|
||||
if table == "mandates":
|
||||
if table_name == "Mandate":
|
||||
record["_hideView"] = False # SYSADMIN can view
|
||||
record["_hideEdit"] = not self.canModify("mandates", record_id)
|
||||
record["_hideDelete"] = not self.canModify("mandates", record_id)
|
||||
elif table == "users":
|
||||
record["_hideEdit"] = not self.canModify(Mandate, record_id)
|
||||
record["_hideDelete"] = not self.canModify(Mandate, record_id)
|
||||
elif table_name == "UserInDB":
|
||||
record["_hideView"] = False # Everyone can view users they have access to
|
||||
# SysAdmin can edit/delete any user
|
||||
if self.privilege == UserPrivilege.SYSADMIN:
|
||||
|
|
@ -107,7 +108,7 @@ class AppAccess:
|
|||
else:
|
||||
record["_hideEdit"] = record.get("id") != self.userId
|
||||
record["_hideDelete"] = True # Regular users cannot delete users
|
||||
elif table == "connections":
|
||||
elif table_name == "UserConnection":
|
||||
# Everyone can view connections they have access to
|
||||
record["_hideView"] = False
|
||||
# SysAdmin can edit/delete any connection
|
||||
|
|
@ -116,7 +117,7 @@ class AppAccess:
|
|||
record["_hideDelete"] = False
|
||||
# Admin can edit/delete connections for users in their mandate
|
||||
elif self.privilege == UserPrivilege.ADMIN:
|
||||
users: List[Dict[str, Any]] = self.db.getRecordset("users", recordFilter={"mandateId": self.mandateId})
|
||||
users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId})
|
||||
user_ids: List[str] = [str(u["id"]) for u in users]
|
||||
record["_hideEdit"] = record.get("userId") not in user_ids
|
||||
record["_hideDelete"] = record.get("userId") not in user_ids
|
||||
|
|
@ -125,35 +126,37 @@ class AppAccess:
|
|||
record["_hideEdit"] = record.get("userId") != self.userId
|
||||
record["_hideDelete"] = record.get("userId") != self.userId
|
||||
|
||||
elif table == "auth_events":
|
||||
elif table_name == "AuthEvent":
|
||||
# Only show auth events for the current user or if admin
|
||||
if self.privilege in [UserPrivilege.SYSADMIN, UserPrivilege.ADMIN]:
|
||||
record["_hideView"] = False
|
||||
else:
|
||||
record["_hideView"] = record.get("userId") != self.userId
|
||||
record["_hideEdit"] = True # Auth events can't be edited
|
||||
record["_hideDelete"] = not self.canModify("auth_events", record_id)
|
||||
record["_hideDelete"] = not self.canModify(AuthEvent, record_id)
|
||||
else:
|
||||
# Default access control for other tables
|
||||
record["_hideView"] = False
|
||||
record["_hideEdit"] = not self.canModify(table, record_id)
|
||||
record["_hideDelete"] = not self.canModify(table, record_id)
|
||||
record["_hideEdit"] = not self.canModify(model_class, record_id)
|
||||
record["_hideDelete"] = not self.canModify(model_class, record_id)
|
||||
|
||||
return filtered_records
|
||||
|
||||
def canModify(self, table: str, recordId: Optional[str] = None) -> bool:
|
||||
def canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Checks if the current user can modify (create/update/delete) records in a table.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
model_class: Pydantic model class for the table
|
||||
recordId: Optional record ID for specific record check
|
||||
|
||||
Returns:
|
||||
Boolean indicating permission
|
||||
"""
|
||||
table_name = model_class.__name__
|
||||
|
||||
# For mandates, only SYSADMIN can modify
|
||||
if table == "mandates":
|
||||
if table_name == "Mandate":
|
||||
return self.privilege == UserPrivilege.SYSADMIN
|
||||
|
||||
# System admins can modify anything else
|
||||
|
|
@ -163,17 +166,17 @@ class AppAccess:
|
|||
# Check specific record permissions
|
||||
if recordId is not None:
|
||||
# Get the record to check ownership
|
||||
records: List[Dict[str, Any]] = self.db.getRecordset(table, recordFilter={"id": str(recordId)})
|
||||
records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": str(recordId)})
|
||||
if not records:
|
||||
return False
|
||||
|
||||
record = records[0]
|
||||
|
||||
# Special handling for connections
|
||||
if table == "connections":
|
||||
if table_name == "UserConnection":
|
||||
# Admin can modify connections for users in their mandate
|
||||
if self.privilege == UserPrivilege.ADMIN:
|
||||
users: List[Dict[str, Any]] = self.db.getRecordset("users", recordFilter={"mandateId": self.mandateId})
|
||||
users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId})
|
||||
user_ids: List[str] = [str(u["id"]) for u in users]
|
||||
return record.get("userId") in user_ids
|
||||
# Users can only modify their own connections
|
||||
|
|
|
|||
|
|
@ -353,4 +353,93 @@ class GoogleToken(Token):
|
|||
class MsftToken(Token):
|
||||
"""Microsoft OAuth token model"""
|
||||
pass
|
||||
|
||||
class AuthEvent(BaseModel, ModelMixin):
|
||||
"""Data model for authentication events"""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Unique ID of the auth event",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
userId: str = Field(
|
||||
description="ID of the user this event belongs to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True
|
||||
)
|
||||
eventType: str = Field(
|
||||
description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True
|
||||
)
|
||||
timestamp: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="Unix timestamp when the event occurred",
|
||||
frontend_type="datetime",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True
|
||||
)
|
||||
ipAddress: Optional[str] = Field(
|
||||
default=None,
|
||||
description="IP address from which the event originated",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
userAgent: Optional[str] = Field(
|
||||
default=None,
|
||||
description="User agent string from the request",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
success: bool = Field(
|
||||
default=True,
|
||||
description="Whether the authentication event was successful",
|
||||
frontend_type="boolean",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True
|
||||
)
|
||||
details: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Additional details about the event",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
|
||||
# Register labels for AuthEvent
|
||||
register_model_labels(
|
||||
"AuthEvent",
|
||||
{"en": "Authentication Event", "fr": "Événement d'authentification"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"userId": {"en": "User ID", "fr": "ID utilisateur"},
|
||||
"eventType": {"en": "Event Type", "fr": "Type d'événement"},
|
||||
"timestamp": {"en": "Timestamp", "fr": "Horodatage"},
|
||||
"ipAddress": {"en": "IP Address", "fr": "Adresse IP"},
|
||||
"userAgent": {"en": "User Agent", "fr": "Agent utilisateur"},
|
||||
"success": {"en": "Success", "fr": "Succès"},
|
||||
"details": {"en": "Details", "fr": "Détails"}
|
||||
}
|
||||
)
|
||||
|
||||
class SystemTable(BaseModel, ModelMixin):
|
||||
"""Data model for system table entries"""
|
||||
table_name: str = Field(
|
||||
description="Name of the table",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True
|
||||
)
|
||||
initial_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Initial ID for the table",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
|
||||
|
|
@ -12,15 +12,14 @@ import json
|
|||
from passlib.context import CryptContext
|
||||
import uuid
|
||||
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.connectors.connectorPool import get_connector, return_connector
|
||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp
|
||||
from modules.interfaces.interfaceAppAccess import AppAccess
|
||||
from modules.interfaces.interfaceAppModel import (
|
||||
User, Mandate, UserInDB, UserConnection,
|
||||
AuthAuthority, UserPrivilege,
|
||||
ConnectionStatus, Token
|
||||
ConnectionStatus, Token, AuthEvent
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -81,34 +80,36 @@ class AppObjects:
|
|||
self.db.updateContext(self.userId)
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup method to return connector to pool."""
|
||||
"""Cleanup method to close database connection."""
|
||||
if hasattr(self, 'db') and self.db is not None:
|
||||
try:
|
||||
return_connector(self.db)
|
||||
self.db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error returning connector to pool: {e}")
|
||||
logger.error(f"Error closing database connection: {e}")
|
||||
|
||||
def _initializeDatabase(self):
|
||||
"""Initializes the database connection using connection pool."""
|
||||
"""Initializes the database connection directly."""
|
||||
try:
|
||||
# Get configuration values with defaults
|
||||
dbHost = APP_CONFIG.get("DB_APP_HOST", "_no_config_default_data")
|
||||
dbDatabase = APP_CONFIG.get("DB_APP_DATABASE", "app")
|
||||
dbUser = APP_CONFIG.get("DB_APP_USER")
|
||||
dbPassword = APP_CONFIG.get("DB_APP_PASSWORD_SECRET")
|
||||
dbPort = int(APP_CONFIG.get("DB_APP_PORT", 5432))
|
||||
|
||||
# Ensure the database directory exists
|
||||
os.makedirs(dbHost, exist_ok=True)
|
||||
|
||||
# Get connector from pool with user context
|
||||
self.db = get_connector(
|
||||
# Create database connector directly
|
||||
self.db = DatabaseConnector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
dbPort=dbPort,
|
||||
userId=self.userId
|
||||
)
|
||||
|
||||
# Initialize database system
|
||||
self.db.initDbSystem()
|
||||
|
||||
logger.info(f"Database initialized successfully for user {self.userId}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
|
|
@ -121,8 +122,8 @@ class AppObjects:
|
|||
|
||||
def _initRootMandate(self):
|
||||
"""Creates the Root mandate if it doesn't exist."""
|
||||
existingMandateId = self.getInitialId("mandates")
|
||||
mandates = self.db.getRecordset("mandates")
|
||||
existingMandateId = self.getInitialId(Mandate)
|
||||
mandates = self.db.getRecordset(Mandate)
|
||||
if existingMandateId is None or not mandates:
|
||||
logger.info("Creating Root mandate")
|
||||
rootMandate = Mandate(
|
||||
|
|
@ -130,23 +131,20 @@ class AppObjects:
|
|||
language="en",
|
||||
enabled=True
|
||||
)
|
||||
createdMandate = self.db.recordCreate("mandates", rootMandate.to_dict())
|
||||
createdMandate = self.db.recordCreate(Mandate, rootMandate)
|
||||
logger.info(f"Root mandate created with ID {createdMandate['id']}")
|
||||
|
||||
# Register the initial ID
|
||||
self.db._registerInitialId("mandates", createdMandate['id'])
|
||||
|
||||
# Update mandate context
|
||||
self.mandateId = createdMandate['id']
|
||||
|
||||
def _initAdminUser(self):
|
||||
"""Creates the Admin user if it doesn't exist."""
|
||||
existingUserId = self.getInitialId("users")
|
||||
users = self.db.getRecordset("users")
|
||||
existingUserId = self.getInitialId(UserInDB)
|
||||
users = self.db.getRecordset(UserInDB)
|
||||
if existingUserId is None or not users:
|
||||
logger.info("Creating Admin user")
|
||||
adminUser = UserInDB(
|
||||
mandateId=self.getInitialId("mandates"),
|
||||
mandateId=self.getInitialId(Mandate),
|
||||
username="admin",
|
||||
email="admin@example.com",
|
||||
fullName="Administrator",
|
||||
|
|
@ -157,30 +155,27 @@ class AppObjects:
|
|||
hashedPassword=self._getPasswordHash("The 1st Poweron Admin"), # Use a secure password in production!
|
||||
connections=[]
|
||||
)
|
||||
createdUser = self.db.recordCreate("users", adminUser.to_dict())
|
||||
createdUser = self.db.recordCreate(UserInDB, adminUser)
|
||||
logger.info(f"Admin user created with ID {createdUser['id']}")
|
||||
|
||||
# Register the initial ID
|
||||
self.db._registerInitialId("users", createdUser['id'])
|
||||
|
||||
# Update user context
|
||||
self.currentUser = createdUser
|
||||
self.userId = createdUser.get("id")
|
||||
|
||||
def _uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Unified user access management function that filters data based on user privileges
|
||||
and adds access control attributes.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
model_class: Pydantic model class for the table
|
||||
recordset: Recordset to filter based on access rules
|
||||
|
||||
Returns:
|
||||
Filtered recordset with access control attributes
|
||||
"""
|
||||
# First apply access control
|
||||
filteredRecords = self.access.uam(table, recordset)
|
||||
filteredRecords = self.access.uam(model_class, recordset)
|
||||
|
||||
# Then filter out database-specific fields
|
||||
cleanedRecords = []
|
||||
|
|
@ -191,26 +186,23 @@ class AppObjects:
|
|||
|
||||
return cleanedRecords
|
||||
|
||||
def _canModify(self, table: str, recordId: Optional[str] = None) -> bool:
|
||||
def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Checks if the current user can modify (create/update/delete) records in a table.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
model_class: Pydantic model class for the table
|
||||
recordId: Optional record ID for specific record check
|
||||
|
||||
Returns:
|
||||
Boolean indicating permission
|
||||
"""
|
||||
return self.access.canModify(table, recordId)
|
||||
return self.access.canModify(model_class, recordId)
|
||||
|
||||
def _clearTableCache(self, table: str) -> None:
|
||||
"""Clears the cache for a specific table to ensure fresh data."""
|
||||
self.db.clearTableCache(table)
|
||||
|
||||
def getInitialId(self, table: str) -> Optional[str]:
|
||||
def getInitialId(self, model_class: type) -> Optional[str]:
|
||||
"""Returns the initial ID for a table."""
|
||||
return self.db.getInitialId(table)
|
||||
return self.db.getInitialId(model_class)
|
||||
|
||||
def _getPasswordHash(self, password: str) -> str:
|
||||
"""Creates a hash for a password."""
|
||||
|
|
@ -225,8 +217,8 @@ class AppObjects:
|
|||
def getUsersByMandate(self, mandateId: str) -> List[User]:
|
||||
"""Returns users for a specific mandate if user has access."""
|
||||
# Get users for this mandate
|
||||
users = self.db.getRecordset("users", recordFilter={"mandateId": mandateId})
|
||||
filteredUsers = self._uam("users", users)
|
||||
users = self.db.getRecordset(UserInDB, recordFilter={"mandateId": mandateId})
|
||||
filteredUsers = self._uam(UserInDB, users)
|
||||
|
||||
# Convert to User models
|
||||
return [User.from_dict(user) for user in filteredUsers]
|
||||
|
|
@ -235,7 +227,7 @@ class AppObjects:
|
|||
"""Returns a user by username."""
|
||||
try:
|
||||
# Get users table
|
||||
users = self.db.getRecordset("users")
|
||||
users = self.db.getRecordset(UserInDB)
|
||||
if not users:
|
||||
return None
|
||||
|
||||
|
|
@ -255,7 +247,7 @@ class AppObjects:
|
|||
"""Returns a user by ID if user has access."""
|
||||
try:
|
||||
# Get all users
|
||||
users = self.db.getRecordset("users")
|
||||
users = self.db.getRecordset(UserInDB)
|
||||
if not users:
|
||||
return None
|
||||
|
||||
|
|
@ -263,7 +255,7 @@ class AppObjects:
|
|||
for user_dict in users:
|
||||
if user_dict.get("id") == userId:
|
||||
# Apply access control
|
||||
filteredUsers = self._uam("users", [user_dict])
|
||||
filteredUsers = self._uam(UserInDB, [user_dict])
|
||||
if filteredUsers:
|
||||
return User.from_dict(filteredUsers[0])
|
||||
return None
|
||||
|
|
@ -278,7 +270,7 @@ class AppObjects:
|
|||
"""Returns all connections for a user."""
|
||||
try:
|
||||
# Get connections for this user
|
||||
connections = self.db.getRecordset("connections", recordFilter={"userId": userId})
|
||||
connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
|
||||
|
||||
# Convert to UserConnection objects
|
||||
result = []
|
||||
|
|
@ -345,10 +337,8 @@ class AppObjects:
|
|||
)
|
||||
|
||||
# Save to connections table
|
||||
self.db.recordCreate("connections", connection.to_dict())
|
||||
self.db.recordCreate(UserConnection, connection)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("connections")
|
||||
|
||||
return connection
|
||||
|
||||
|
|
@ -360,7 +350,7 @@ class AppObjects:
|
|||
"""Remove a connection to an external service"""
|
||||
try:
|
||||
# Get connection
|
||||
connections = self.db.getRecordset("connections", recordFilter={
|
||||
connections = self.db.getRecordset(UserConnection, recordFilter={
|
||||
"id": connectionId
|
||||
})
|
||||
|
||||
|
|
@ -368,10 +358,8 @@ class AppObjects:
|
|||
raise ValueError(f"Connection {connectionId} not found")
|
||||
|
||||
# Delete connection
|
||||
self.db.recordDelete("connections", connectionId)
|
||||
self.db.recordDelete(UserConnection, connectionId)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("connections")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing user connection: {str(e)}")
|
||||
|
|
@ -380,7 +368,6 @@ class AppObjects:
|
|||
def authenticateLocalUser(self, username: str, password: str) -> Optional[User]:
|
||||
"""Authenticates a user by username and password using local authentication."""
|
||||
# Clear the users table from cache and reload it
|
||||
self._clearTableCache("users")
|
||||
|
||||
# Get user by username
|
||||
user = self.getUserByUsername(username)
|
||||
|
|
@ -397,7 +384,7 @@ class AppObjects:
|
|||
raise ValueError("User does not have local authentication enabled")
|
||||
|
||||
# Get the full user record with password hash for verification
|
||||
userRecord = self.db.getRecordset("users", recordFilter={"id": user.id})[0]
|
||||
userRecord = self.db.getRecordset(UserInDB, recordFilter={"id": user.id})[0]
|
||||
if not userRecord.get("hashedPassword"):
|
||||
raise ValueError("User has no password set")
|
||||
|
||||
|
|
@ -441,12 +428,10 @@ class AppObjects:
|
|||
)
|
||||
|
||||
# Create user record
|
||||
createdRecord = self.db.recordCreate("users", userData.to_dict())
|
||||
createdRecord = self.db.recordCreate(UserInDB, userData)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create user record")
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("users")
|
||||
|
||||
# Add external connection if provided
|
||||
if externalId and externalUsername:
|
||||
|
|
@ -459,12 +444,11 @@ class AppObjects:
|
|||
)
|
||||
|
||||
# Get created user using the returned ID
|
||||
createdUser = self.db.getRecordset("users", recordFilter={"id": createdRecord["id"]})
|
||||
createdUser = self.db.getRecordset(UserInDB, recordFilter={"id": createdRecord["id"]})
|
||||
if not createdUser or len(createdUser) == 0:
|
||||
raise ValueError("Failed to retrieve created user")
|
||||
|
||||
# Clear cache to ensure fresh data (already done above)
|
||||
# No need for additional cache clearing since _clearTableCache("users") was called
|
||||
|
||||
return User.from_dict(createdUser[0])
|
||||
|
||||
|
|
@ -489,10 +473,8 @@ class AppObjects:
|
|||
updatedUser = User.from_dict(updatedData)
|
||||
|
||||
# Update user record
|
||||
self.db.recordModify("users", userId, updatedUser.to_dict())
|
||||
self.db.recordModify(UserInDB, userId, updatedUser)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("users")
|
||||
|
||||
# Get updated user
|
||||
updatedUser = self.getUser(userId)
|
||||
|
|
@ -519,20 +501,20 @@ class AppObjects:
|
|||
|
||||
|
||||
# Delete user auth events
|
||||
events = self.db.getRecordset("auth_events", recordFilter={"userId": userId})
|
||||
events = self.db.getRecordset(AuthEvent, recordFilter={"userId": userId})
|
||||
for event in events:
|
||||
self.db.recordDelete("auth_events", event["id"])
|
||||
self.db.recordDelete(AuthEvent, event["id"])
|
||||
|
||||
# Delete user tokens
|
||||
tokens = self.db.getRecordset("tokens", recordFilter={"userId": userId})
|
||||
tokens = self.db.getRecordset(Token, recordFilter={"userId": userId})
|
||||
for token in tokens:
|
||||
self.db.recordDelete("tokens", token["id"])
|
||||
self.db.recordDelete(Token, token["id"])
|
||||
|
||||
|
||||
# Delete user connections
|
||||
connections = self.db.getRecordset("connections", recordFilter={"userId": userId})
|
||||
connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
|
||||
for conn in connections:
|
||||
self.db.recordDelete("connections", conn["id"])
|
||||
self.db.recordDelete(UserConnection, conn["id"])
|
||||
|
||||
logger.info(f"All referenced data for user {userId} has been deleted")
|
||||
|
||||
|
|
@ -548,19 +530,17 @@ class AppObjects:
|
|||
if not user:
|
||||
raise ValueError(f"User {userId} not found")
|
||||
|
||||
if not self._canModify("users", userId):
|
||||
if not self._canModify(UserInDB, userId):
|
||||
raise PermissionError(f"No permission to delete user {userId}")
|
||||
|
||||
# Delete all referenced data first
|
||||
self._deleteUserReferencedData(userId)
|
||||
|
||||
# Delete user record
|
||||
success = self.db.recordDelete("users", userId)
|
||||
success = self.db.recordDelete(UserInDB, userId)
|
||||
if not success:
|
||||
raise ValueError(f"Failed to delete user {userId}")
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("users")
|
||||
|
||||
logger.info(f"User {userId} successfully deleted")
|
||||
return True
|
||||
|
|
@ -573,17 +553,17 @@ class AppObjects:
|
|||
|
||||
def getAllMandates(self) -> List[Mandate]:
|
||||
"""Returns all mandates based on user access level."""
|
||||
allMandates = self.db.getRecordset("mandates")
|
||||
filteredMandates = self._uam("mandates", allMandates)
|
||||
allMandates = self.db.getRecordset(Mandate)
|
||||
filteredMandates = self._uam(Mandate, allMandates)
|
||||
return [Mandate.from_dict(mandate) for mandate in filteredMandates]
|
||||
|
||||
def getMandate(self, mandateId: str) -> Optional[Mandate]:
|
||||
"""Returns a mandate by ID if user has access."""
|
||||
mandates = self.db.getRecordset("mandates", recordFilter={"id": mandateId})
|
||||
mandates = self.db.getRecordset(Mandate, recordFilter={"id": mandateId})
|
||||
if not mandates:
|
||||
return None
|
||||
|
||||
filteredMandates = self._uam("mandates", mandates)
|
||||
filteredMandates = self._uam(Mandate, mandates)
|
||||
if not filteredMandates:
|
||||
return None
|
||||
|
||||
|
|
@ -591,7 +571,7 @@ class AppObjects:
|
|||
|
||||
def createMandate(self, name: str, language: str = "en") -> Mandate:
|
||||
"""Creates a new mandate if user has permission."""
|
||||
if not self._canModify("mandates"):
|
||||
if not self._canModify(Mandate):
|
||||
raise PermissionError("No permission to create mandates")
|
||||
|
||||
# Create mandate data using model
|
||||
|
|
@ -601,12 +581,10 @@ class AppObjects:
|
|||
)
|
||||
|
||||
# Create mandate record
|
||||
createdRecord = self.db.recordCreate("mandates", mandateData.to_dict())
|
||||
createdRecord = self.db.recordCreate(Mandate, mandateData)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create mandate record")
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("mandates")
|
||||
|
||||
return Mandate.from_dict(createdRecord)
|
||||
|
||||
|
|
@ -614,7 +592,7 @@ class AppObjects:
|
|||
"""Updates a mandate if user has access."""
|
||||
try:
|
||||
# First check if user has permission to modify mandates
|
||||
if not self._canModify("mandates", mandateId):
|
||||
if not self._canModify(Mandate, mandateId):
|
||||
raise PermissionError(f"No permission to update mandate {mandateId}")
|
||||
|
||||
# Get mandate with access control
|
||||
|
|
@ -628,10 +606,9 @@ class AppObjects:
|
|||
updatedMandate = Mandate.from_dict(updatedData)
|
||||
|
||||
# Update mandate record
|
||||
self.db.recordModify("mandates", mandateId, updatedMandate.to_dict())
|
||||
self.db.recordModify(Mandate, mandateId, updatedMandate)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("mandates")
|
||||
|
||||
# Get updated mandate
|
||||
updatedMandate = self.getMandate(mandateId)
|
||||
|
|
@ -652,7 +629,7 @@ class AppObjects:
|
|||
if not mandate:
|
||||
return False
|
||||
|
||||
if not self._canModify("mandates", mandateId):
|
||||
if not self._canModify(Mandate, mandateId):
|
||||
raise PermissionError(f"No permission to delete mandate {mandateId}")
|
||||
|
||||
# Check if mandate has users
|
||||
|
|
@ -661,10 +638,9 @@ class AppObjects:
|
|||
raise ValueError(f"Cannot delete mandate {mandateId} with existing users")
|
||||
|
||||
# Delete mandate
|
||||
success = self.db.recordDelete("mandates", mandateId)
|
||||
success = self.db.recordDelete(Mandate, mandateId)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("mandates")
|
||||
|
||||
return success
|
||||
|
||||
|
|
@ -675,11 +651,11 @@ class AppObjects:
|
|||
def _getInitialUser(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get the initial user record directly from database without access control."""
|
||||
try:
|
||||
initialUserId = self.db.getInitialId("users")
|
||||
initialUserId = self.getInitialId(UserInDB)
|
||||
if not initialUserId:
|
||||
return None
|
||||
|
||||
users = self.db.getRecordset("users", recordFilter={"id": initialUserId})
|
||||
users = self.db.getRecordset(UserInDB, recordFilter={"id": initialUserId})
|
||||
return users[0] if users else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting initial user: {str(e)}")
|
||||
|
|
@ -742,7 +718,7 @@ class AppObjects:
|
|||
# If replace_existing is True, delete old access tokens for this user and authority first
|
||||
if replace_existing:
|
||||
try:
|
||||
old_tokens = self.db.getRecordset("tokens", recordFilter={
|
||||
old_tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"userId": self.currentUser.id,
|
||||
"authority": token.authority,
|
||||
"connectionId": None # Ensure we only delete access tokens
|
||||
|
|
@ -750,7 +726,7 @@ class AppObjects:
|
|||
deleted_count = 0
|
||||
for old_token in old_tokens:
|
||||
if old_token["id"] != token.id: # Don't delete the new token if it already exists
|
||||
self.db.recordDelete("tokens", old_token["id"])
|
||||
self.db.recordDelete(Token, old_token["id"])
|
||||
deleted_count += 1
|
||||
logger.debug(f"Deleted old access token {old_token['id']} for user {self.currentUser.id} and authority {token.authority}")
|
||||
|
||||
|
|
@ -767,10 +743,8 @@ class AppObjects:
|
|||
token_dict["userId"] = self.currentUser.id
|
||||
|
||||
# Save to database
|
||||
self.db.recordCreate("tokens", token_dict)
|
||||
self.db.recordCreate(Token, token_dict)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("tokens")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving access token: {str(e)}")
|
||||
|
|
@ -799,13 +773,13 @@ class AppObjects:
|
|||
# If replace_existing is True, delete old tokens for this connectionId first
|
||||
if replace_existing:
|
||||
try:
|
||||
old_tokens = self.db.getRecordset("tokens", recordFilter={
|
||||
old_tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"connectionId": token.connectionId
|
||||
})
|
||||
deleted_count = 0
|
||||
for old_token in old_tokens:
|
||||
if old_token["id"] != token.id: # Don't delete the new token if it already exists
|
||||
self.db.recordDelete("tokens", old_token["id"])
|
||||
self.db.recordDelete(Token, old_token["id"])
|
||||
deleted_count += 1
|
||||
logger.debug(f"Deleted old token {old_token['id']} for connectionId {token.connectionId}")
|
||||
|
||||
|
|
@ -822,10 +796,8 @@ class AppObjects:
|
|||
token_dict["userId"] = self.currentUser.id
|
||||
|
||||
# Save to database
|
||||
self.db.recordCreate("tokens", token_dict)
|
||||
self.db.recordCreate(Token, token_dict)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("tokens")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving connection token: {str(e)}")
|
||||
|
|
@ -839,7 +811,7 @@ class AppObjects:
|
|||
raise ValueError("No valid user context available for token retrieval")
|
||||
|
||||
# Get access tokens for this user and authority (must NOT have connectionId)
|
||||
tokens = self.db.getRecordset("tokens", recordFilter={
|
||||
tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"userId": self.currentUser.id,
|
||||
"authority": authority,
|
||||
"connectionId": None # Ensure we only get access tokens
|
||||
|
|
@ -888,7 +860,7 @@ class AppObjects:
|
|||
|
||||
# Get token for this specific connection
|
||||
# Query for specific connection
|
||||
tokens = self.db.getRecordset("tokens", recordFilter={
|
||||
tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"connectionId": connectionId
|
||||
})
|
||||
|
||||
|
|
@ -899,7 +871,7 @@ class AppObjects:
|
|||
logger.debug(f"getConnectionToken: Token {i}: id={token.get('id')}, expiresAt={token.get('expiresAt')}, createdAt={token.get('createdAt')}")
|
||||
else:
|
||||
# Debug: Check if there are any tokens at all in the database
|
||||
all_tokens = self.db.getRecordset("tokens", recordFilter={})
|
||||
all_tokens = self.db.getRecordset(Token, recordFilter={})
|
||||
logger.debug(f"getConnectionToken: No tokens found for connectionId {connectionId}. Total tokens in database: {len(all_tokens)}")
|
||||
if all_tokens:
|
||||
logger.debug(f"getConnectionToken: Sample tokens: {[{'id': t.get('id'), 'connectionId': t.get('connectionId'), 'authority': t.get('authority')} for t in all_tokens[:3]]}")
|
||||
|
|
@ -956,7 +928,7 @@ class AppObjects:
|
|||
raise ValueError("No valid user context available for token deletion")
|
||||
|
||||
# Get access tokens to delete (must NOT have connectionId)
|
||||
tokens = self.db.getRecordset("tokens", recordFilter={
|
||||
tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"userId": self.currentUser.id,
|
||||
"authority": authority,
|
||||
"connectionId": None # Ensure we only delete access tokens
|
||||
|
|
@ -964,10 +936,8 @@ class AppObjects:
|
|||
|
||||
# Delete each token
|
||||
for token in tokens:
|
||||
self.db.recordDelete("tokens", token["id"])
|
||||
self.db.recordDelete(Token, token["id"])
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("tokens")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting access token: {str(e)}")
|
||||
|
|
@ -981,16 +951,14 @@ class AppObjects:
|
|||
raise ValueError("connectionId is required for deleteConnectionTokenByConnectionId")
|
||||
|
||||
# Get connection tokens to delete
|
||||
tokens = self.db.getRecordset("tokens", recordFilter={
|
||||
tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"connectionId": connectionId
|
||||
})
|
||||
|
||||
# Delete each token
|
||||
for token in tokens:
|
||||
self.db.recordDelete("tokens", token["id"])
|
||||
self.db.recordDelete(Token, token["id"])
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("tokens")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting connection token for connectionId {connectionId}: {str(e)}")
|
||||
|
|
@ -1005,17 +973,16 @@ class AppObjects:
|
|||
cleaned_count = 0
|
||||
|
||||
# Get all tokens
|
||||
all_tokens = self.db.getRecordset("tokens", recordFilter={})
|
||||
all_tokens = self.db.getRecordset(Token, recordFilter={})
|
||||
|
||||
for token_data in all_tokens:
|
||||
if token_data.get("expiresAt") and token_data.get("expiresAt") < current_time:
|
||||
# Token is expired, delete it
|
||||
self.db.recordDelete("tokens", token_data["id"])
|
||||
self.db.recordDelete(Token, token_data["id"])
|
||||
cleaned_count += 1
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
if cleaned_count > 0:
|
||||
self._clearTableCache("tokens")
|
||||
logger.info(f"Cleaned up {cleaned_count} expired tokens")
|
||||
|
||||
return cleaned_count
|
||||
|
|
@ -1072,16 +1039,23 @@ def getRootUser() -> User:
|
|||
tempInterface = AppObjects()
|
||||
|
||||
# Get the initial user directly
|
||||
initialUserId = tempInterface.db.getInitialId("users")
|
||||
initialUserId = tempInterface.getInitialId(UserInDB)
|
||||
if not initialUserId:
|
||||
raise ValueError("No initial user ID found in database")
|
||||
|
||||
users = tempInterface.db.getRecordset("users", recordFilter={"id": initialUserId})
|
||||
users = tempInterface.db.getRecordset(UserInDB, recordFilter={"id": initialUserId})
|
||||
if not users:
|
||||
raise ValueError("Initial user not found in database")
|
||||
|
||||
|
||||
logger.debug(f"Retrieved user data: {users[0]}")
|
||||
|
||||
# Convert to User model and return the model instance
|
||||
return User.from_dict(users[0])
|
||||
user_data = users[0]
|
||||
logger.debug(f"User data keys: {list(user_data.keys())}")
|
||||
logger.debug(f"User id: {user_data.get('id')}")
|
||||
logger.debug(f"User mandateId: {user_data.get('mandateId')}")
|
||||
|
||||
return User.parse_obj(user_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting root user: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Handles user access management and permission checks.
|
|||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from modules.interfaces.interfaceAppModel import User, UserPrivilege
|
||||
from modules.interfaces.interfaceChatModel import ChatWorkflow, ChatMessage, ChatLog, ChatStat, ChatDocument
|
||||
|
||||
class ChatAccess:
|
||||
"""
|
||||
|
|
@ -23,19 +24,20 @@ class ChatAccess:
|
|||
|
||||
self.db = db
|
||||
|
||||
def uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Unified user access management function that filters data based on user privileges
|
||||
and adds access control attributes.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
model_class: Pydantic model class for the table
|
||||
recordset: Recordset to filter based on access rules
|
||||
|
||||
Returns:
|
||||
Filtered recordset with access control attributes
|
||||
"""
|
||||
userPrivilege = self.currentUser.privilege
|
||||
table_name = model_class.__name__
|
||||
filtered_records = []
|
||||
|
||||
# Apply filtering based on privilege
|
||||
|
|
@ -54,32 +56,32 @@ class ChatAccess:
|
|||
record_id = record.get("id")
|
||||
|
||||
# Set access control flags based on user permissions
|
||||
if table == "workflows":
|
||||
if table_name == "ChatWorkflow":
|
||||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify("workflows", record_id)
|
||||
record["_hideDelete"] = not self.canModify("workflows", record_id)
|
||||
elif table == "workflowMessages":
|
||||
record["_hideEdit"] = not self.canModify(ChatWorkflow, record_id)
|
||||
record["_hideDelete"] = not self.canModify(ChatWorkflow, record_id)
|
||||
elif table_name == "ChatMessage":
|
||||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify("workflows", record.get("workflowId"))
|
||||
record["_hideDelete"] = not self.canModify("workflows", record.get("workflowId"))
|
||||
elif table == "workflowLogs":
|
||||
record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
elif table_name == "ChatLog":
|
||||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify("workflows", record.get("workflowId"))
|
||||
record["_hideDelete"] = not self.canModify("workflows", record.get("workflowId"))
|
||||
record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
else:
|
||||
# Default access control for other tables
|
||||
record["_hideView"] = False
|
||||
record["_hideEdit"] = not self.canModify(table, record_id)
|
||||
record["_hideDelete"] = not self.canModify(table, record_id)
|
||||
record["_hideEdit"] = not self.canModify(model_class, record_id)
|
||||
record["_hideDelete"] = not self.canModify(model_class, record_id)
|
||||
|
||||
return filtered_records
|
||||
|
||||
def canModify(self, table: str, recordId: Optional[str] = None) -> bool:
|
||||
def canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Checks if the current user can modify (create/update/delete) records in a table.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
model_class: Pydantic model class for the table
|
||||
recordId: Optional record ID for specific record check
|
||||
|
||||
Returns:
|
||||
|
|
@ -94,7 +96,7 @@ class ChatAccess:
|
|||
# For regular users and admins, check specific cases
|
||||
if recordId is not None:
|
||||
# Get the record to check ownership
|
||||
records: List[Dict[str, Any]] = self.db.getRecordset(table, recordFilter={"id": recordId})
|
||||
records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": recordId})
|
||||
if not records:
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -174,6 +174,7 @@ register_model_labels(
|
|||
class ChatDocument(BaseModel, ModelMixin):
|
||||
"""Data model for a chat document"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
messageId: str = Field(description="Foreign key to message")
|
||||
fileId: str = Field(description="Foreign key to file")
|
||||
|
||||
# Direct file attributes (copied from file object)
|
||||
|
|
@ -197,6 +198,7 @@ register_model_labels(
|
|||
{"en": "Chat Document", "fr": "Document de chat"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"messageId": {"en": "Message ID", "fr": "ID du message"},
|
||||
"fileId": {"en": "File ID", "fr": "ID du fichier"},
|
||||
"roundNumber": {"en": "Round Number", "fr": "Numéro de tour"},
|
||||
"taskNumber": {"en": "Task Number", "fr": "Numéro de tâche"},
|
||||
|
|
@ -400,6 +402,8 @@ register_model_labels(
|
|||
class ChatStat(BaseModel, ModelMixin):
|
||||
"""Data model for chat statistics - ONLY statistics, not workflow progress"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)")
|
||||
messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)")
|
||||
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
|
||||
tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
|
||||
bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
|
||||
|
|
@ -413,6 +417,8 @@ register_model_labels(
|
|||
{"en": "Chat Statistics", "fr": "Statistiques de chat"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"workflowId": {"en": "Workflow ID", "fr": "ID du workflow"},
|
||||
"messageId": {"en": "Message ID", "fr": "ID du message"},
|
||||
"processingTime": {"en": "Processing Time", "fr": "Temps de traitement"},
|
||||
"tokenCount": {"en": "Token Count", "fr": "Nombre de tokens"},
|
||||
"bytesSent": {"en": "Bytes Sent", "fr": "Octets envoyés"},
|
||||
|
|
@ -650,8 +656,8 @@ register_model_labels(
|
|||
class TaskStep(BaseModel, ModelMixin):
|
||||
id: str
|
||||
objective: str
|
||||
dependencies: Optional[list[str]] = []
|
||||
success_criteria: Optional[list[str]] = []
|
||||
dependencies: Optional[list[str]] = Field(default_factory=list)
|
||||
success_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
estimated_complexity: Optional[str] = None
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
|
||||
|
|
@ -733,23 +739,23 @@ class TaskContext(BaseModel, ModelMixin):
|
|||
|
||||
# Available resources
|
||||
available_documents: Optional[str] = "No documents available"
|
||||
available_connections: Optional[list[str]] = []
|
||||
available_connections: Optional[list[str]] = Field(default_factory=list)
|
||||
|
||||
# Previous execution state
|
||||
previous_results: Optional[list[str]] = []
|
||||
previous_results: Optional[list[str]] = Field(default_factory=list)
|
||||
previous_handover: Optional[TaskHandover] = None
|
||||
|
||||
# Current execution state
|
||||
improvements: Optional[list[str]] = []
|
||||
improvements: Optional[list[str]] = Field(default_factory=list)
|
||||
retry_count: Optional[int] = 0
|
||||
previous_action_results: Optional[list] = []
|
||||
previous_action_results: Optional[list] = Field(default_factory=list)
|
||||
previous_review_result: Optional[dict] = None
|
||||
is_regeneration: Optional[bool] = False
|
||||
|
||||
# Failure analysis
|
||||
failure_patterns: Optional[list[str]] = []
|
||||
failed_actions: Optional[list] = []
|
||||
successful_actions: Optional[list] = []
|
||||
failure_patterns: Optional[list[str]] = Field(default_factory=list)
|
||||
failed_actions: Optional[list] = Field(default_factory=list)
|
||||
successful_actions: Optional[list] = Field(default_factory=list)
|
||||
|
||||
# Criteria progress tracking for retries
|
||||
criteria_progress: Optional[dict] = None
|
||||
|
|
@ -771,20 +777,20 @@ class TaskContext(BaseModel, ModelMixin):
|
|||
|
||||
class ReviewContext(BaseModel, ModelMixin):
|
||||
task_step: TaskStep
|
||||
task_actions: Optional[list] = []
|
||||
action_results: Optional[list] = []
|
||||
step_result: Optional[dict] = {}
|
||||
task_actions: Optional[list] = Field(default_factory=list)
|
||||
action_results: Optional[list] = Field(default_factory=list)
|
||||
step_result: Optional[dict] = Field(default_factory=dict)
|
||||
workflow_id: Optional[str] = None
|
||||
previous_results: Optional[list[str]] = []
|
||||
previous_results: Optional[list[str]] = Field(default_factory=list)
|
||||
|
||||
class ReviewResult(BaseModel, ModelMixin):
|
||||
status: str
|
||||
reason: Optional[str] = None
|
||||
improvements: Optional[list[str]] = []
|
||||
improvements: Optional[list[str]] = Field(default_factory=list)
|
||||
quality_score: Optional[int] = 5
|
||||
missing_outputs: Optional[list[str]] = []
|
||||
met_criteria: Optional[list[str]] = []
|
||||
unmet_criteria: Optional[list[str]] = []
|
||||
missing_outputs: Optional[list[str]] = Field(default_factory=list)
|
||||
met_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
unmet_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
confidence: Optional[float] = 0.5
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -5,7 +5,9 @@ Handles user access management and permission checks.
|
|||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from modules.interfaces.interfaceAppModel import User
|
||||
from modules.interfaces.interfaceAppModel import User, UserInDB
|
||||
from modules.interfaces.interfaceComponentModel import Prompt, FileItem, FileData
|
||||
from modules.interfaces.interfaceChatModel import ChatWorkflow, ChatMessage, ChatLog
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -47,19 +49,20 @@ class ComponentAccess:
|
|||
|
||||
return True
|
||||
|
||||
def uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Unified user access management function that filters data based on user privileges
|
||||
and adds access control attributes.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
model_class: Pydantic model class for the table
|
||||
recordset: Recordset to filter based on access rules
|
||||
|
||||
Returns:
|
||||
Filtered recordset with access control attributes
|
||||
"""
|
||||
userPrivilege = self.privilege
|
||||
table_name = model_class.__name__
|
||||
|
||||
filtered_records = []
|
||||
|
||||
|
|
@ -73,9 +76,9 @@ class ComponentAccess:
|
|||
filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId]
|
||||
else: # Regular users
|
||||
# For prompts, users can see all prompts from their mandate
|
||||
if table == "prompts":
|
||||
if table_name == "Prompt":
|
||||
filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId]
|
||||
elif table == "users":
|
||||
elif table_name == "UserInDB":
|
||||
# For users table, users can only see their own record
|
||||
filtered_records = [r for r in recordset if r.get("id") == self.userId]
|
||||
else:
|
||||
|
|
@ -90,32 +93,32 @@ class ComponentAccess:
|
|||
record_id = record.get("id")
|
||||
|
||||
# Set access control flags based on user permissions
|
||||
if table == "prompts":
|
||||
if table_name == "Prompt":
|
||||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify("prompts", record_id)
|
||||
record["_hideDelete"] = not self.canModify("prompts", record_id)
|
||||
record["_hideEdit"] = not self.canModify(Prompt, record_id)
|
||||
record["_hideDelete"] = not self.canModify(Prompt, record_id)
|
||||
|
||||
# Add attribute-level permissions for mandateId
|
||||
if "mandateId" in record:
|
||||
record["_hideEdit_mandateId"] = not self.canModifyAttribute("prompts", "mandateId")
|
||||
elif table == "files":
|
||||
record["_hideEdit_mandateId"] = not self.canModifyAttribute(Prompt, "mandateId")
|
||||
elif table_name == "FileItem":
|
||||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify("files", record_id)
|
||||
record["_hideDelete"] = not self.canModify("files", record_id)
|
||||
record["_hideDownload"] = not self.canModify("files", record_id)
|
||||
elif table == "workflows":
|
||||
record["_hideEdit"] = not self.canModify(FileItem, record_id)
|
||||
record["_hideDelete"] = not self.canModify(FileItem, record_id)
|
||||
record["_hideDownload"] = not self.canModify(FileItem, record_id)
|
||||
elif table_name == "ChatWorkflow":
|
||||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify("workflows", record_id)
|
||||
record["_hideDelete"] = not self.canModify("workflows", record_id)
|
||||
elif table == "workflowMessages":
|
||||
record["_hideEdit"] = not self.canModify(ChatWorkflow, record_id)
|
||||
record["_hideDelete"] = not self.canModify(ChatWorkflow, record_id)
|
||||
elif table_name == "ChatMessage":
|
||||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify("workflows", record.get("workflowId"))
|
||||
record["_hideDelete"] = not self.canModify("workflows", record.get("workflowId"))
|
||||
elif table == "workflowLogs":
|
||||
record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
elif table_name == "ChatLog":
|
||||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify("workflows", record.get("workflowId"))
|
||||
record["_hideDelete"] = not self.canModify("workflows", record.get("workflowId"))
|
||||
elif table == "users":
|
||||
record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
elif table_name == "UserInDB":
|
||||
# For users table, users can only modify their own connections
|
||||
record["_hideView"] = False
|
||||
record["_hideEdit"] = record_id != self.userId
|
||||
|
|
@ -128,17 +131,17 @@ class ComponentAccess:
|
|||
else:
|
||||
# Default access control for other tables
|
||||
record["_hideView"] = False
|
||||
record["_hideEdit"] = not self.canModify(table, record_id)
|
||||
record["_hideDelete"] = not self.canModify(table, record_id)
|
||||
record["_hideEdit"] = not self.canModify(model_class, record_id)
|
||||
record["_hideDelete"] = not self.canModify(model_class, record_id)
|
||||
|
||||
return filtered_records
|
||||
|
||||
def canModify(self, table: str, recordId: Optional[int] = None) -> bool:
|
||||
def canModify(self, model_class: type, recordId: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Checks if the current user can modify (create/update/delete) records in a table.
|
||||
|
||||
Args:
|
||||
table: Name of the table
|
||||
model_class: Pydantic model class for the table
|
||||
recordId: Optional record ID for specific record check
|
||||
|
||||
Returns:
|
||||
|
|
@ -153,14 +156,14 @@ class ComponentAccess:
|
|||
# For regular users and admins, check specific cases
|
||||
if recordId is not None:
|
||||
# Get the record to check ownership
|
||||
records: List[Dict[str, Any]] = self.db.getRecordset(table, recordFilter={"id": recordId})
|
||||
records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": recordId})
|
||||
if not records:
|
||||
return False
|
||||
|
||||
record = records[0]
|
||||
|
||||
# Special case for users table - users can modify their own connections
|
||||
if table == "users":
|
||||
if model_class.__name__ == "UserInDB":
|
||||
if record.get("id") == self.userId:
|
||||
return True
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -14,11 +14,10 @@ from modules.interfaces.interfaceComponentAccess import ComponentAccess
|
|||
from modules.interfaces.interfaceComponentModel import (
|
||||
FilePreview, Prompt, FileItem, FileData
|
||||
)
|
||||
from modules.interfaces.interfaceAppModel import User
|
||||
from modules.interfaces.interfaceAppModel import User, Mandate
|
||||
|
||||
# DYNAMIC PART: Connectors to the Interface
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.connectors.connectorPool import get_connector, return_connector
|
||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||
|
||||
# Basic Configurations
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
|
@ -90,35 +89,38 @@ class ComponentObjects:
|
|||
self.db.updateContext(self.userId)
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup method to return connector to pool."""
|
||||
"""Cleanup method to close database connection."""
|
||||
if hasattr(self, 'db') and self.db is not None:
|
||||
try:
|
||||
return_connector(self.db)
|
||||
self.db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error returning connector to pool: {e}")
|
||||
logger.error(f"Error closing database connection: {e}")
|
||||
|
||||
logger.debug(f"User context set: userId={self.userId}")
|
||||
|
||||
def _initializeDatabase(self):
|
||||
"""Initializes the database connection."""
|
||||
"""Initializes the database connection directly."""
|
||||
try:
|
||||
# Get configuration values with defaults
|
||||
dbHost = APP_CONFIG.get("DB_MANAGEMENT_HOST", "_no_config_default_data")
|
||||
dbDatabase = APP_CONFIG.get("DB_MANAGEMENT_DATABASE", "management")
|
||||
dbUser = APP_CONFIG.get("DB_MANAGEMENT_USER")
|
||||
dbPassword = APP_CONFIG.get("DB_MANAGEMENT_PASSWORD_SECRET")
|
||||
dbPort = int(APP_CONFIG.get("DB_MANAGEMENT_PORT"))
|
||||
|
||||
# Ensure the database directory exists
|
||||
os.makedirs(dbHost, exist_ok=True)
|
||||
|
||||
self.db = get_connector(
|
||||
# Create database connector directly
|
||||
self.db = DatabaseConnector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
dbPort=dbPort,
|
||||
userId=self.userId if hasattr(self, 'userId') else None
|
||||
)
|
||||
|
||||
# Initialize database system
|
||||
self.db.initDbSystem()
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
|
|
@ -142,7 +144,7 @@ class ComponentObjects:
|
|||
"""Initializes standard prompts if they don't exist yet."""
|
||||
try:
|
||||
# Check if any prompts exist
|
||||
existingPrompts = self.db.getRecordset("prompts")
|
||||
existingPrompts = self.db.getRecordset(Prompt)
|
||||
if existingPrompts:
|
||||
logger.info("Prompts already exist, skipping initialization")
|
||||
return
|
||||
|
|
@ -152,7 +154,7 @@ class ComponentObjects:
|
|||
rootInterface = getRootInterface()
|
||||
|
||||
# Get initial mandate ID through the root interface
|
||||
mandateId = rootInterface.getInitialId("mandates")
|
||||
mandateId = rootInterface.getInitialId(Mandate)
|
||||
if not mandateId:
|
||||
logger.error("No initial mandate ID found")
|
||||
return
|
||||
|
|
@ -205,7 +207,7 @@ class ComponentObjects:
|
|||
|
||||
# Create prompts
|
||||
for prompt in standardPrompts:
|
||||
self.db.recordCreate("prompts", prompt.to_dict())
|
||||
self.db.recordCreate(Prompt, prompt)
|
||||
logger.info(f"Created standard prompt: {prompt.name}")
|
||||
|
||||
# Restore original user context if it existed
|
||||
|
|
@ -228,10 +230,10 @@ class ComponentObjects:
|
|||
self.access = None
|
||||
self.db.updateContext("") # Reset database context
|
||||
|
||||
def _uam(self, table: str, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Delegate to access control module."""
|
||||
# First apply access control
|
||||
filteredRecords = self.access.uam(table, recordset)
|
||||
filteredRecords = self.access.uam(model_class, recordset)
|
||||
|
||||
# Then filter out database-specific fields
|
||||
cleanedRecords = []
|
||||
|
|
@ -242,19 +244,16 @@ class ComponentObjects:
|
|||
|
||||
return cleanedRecords
|
||||
|
||||
def _canModify(self, table: str, recordId: Optional[str] = None) -> bool:
|
||||
def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
|
||||
"""Delegate to access control module."""
|
||||
return self.access.canModify(table, recordId)
|
||||
return self.access.canModify(model_class, recordId)
|
||||
|
||||
def _clearTableCache(self, table: str) -> None:
|
||||
"""Clears the cache for a specific table to ensure fresh data."""
|
||||
self.db.clearTableCache(table)
|
||||
|
||||
# Utilities
|
||||
|
||||
def getInitialId(self, table: str) -> Optional[str]:
|
||||
def getInitialId(self, model_class: type) -> Optional[str]:
|
||||
"""Returns the initial ID for a table."""
|
||||
return self.db.getInitialId(table)
|
||||
return self.db.getInitialId(model_class)
|
||||
|
||||
|
||||
|
||||
|
|
@ -263,8 +262,8 @@ class ComponentObjects:
|
|||
def getAllPrompts(self) -> List[Prompt]:
|
||||
"""Returns prompts based on user access level."""
|
||||
try:
|
||||
allPrompts = self.db.getRecordset("prompts")
|
||||
filteredPrompts = self._uam("prompts", allPrompts)
|
||||
allPrompts = self.db.getRecordset(Prompt)
|
||||
filteredPrompts = self._uam(Prompt, allPrompts)
|
||||
|
||||
# Convert to Prompt objects
|
||||
return [Prompt.from_dict(prompt) for prompt in filteredPrompts]
|
||||
|
|
@ -275,25 +274,23 @@ class ComponentObjects:
|
|||
|
||||
def getPrompt(self, promptId: str) -> Optional[Prompt]:
|
||||
"""Returns a prompt by ID if user has access."""
|
||||
prompts = self.db.getRecordset("prompts", recordFilter={"id": promptId})
|
||||
prompts = self.db.getRecordset(Prompt, recordFilter={"id": promptId})
|
||||
if not prompts:
|
||||
return None
|
||||
|
||||
filteredPrompts = self._uam("prompts", prompts)
|
||||
filteredPrompts = self._uam(Prompt, prompts)
|
||||
return Prompt.from_dict(filteredPrompts[0]) if filteredPrompts else None
|
||||
|
||||
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Creates a new prompt if user has permission."""
|
||||
if not self._canModify("prompts"):
|
||||
if not self._canModify(Prompt):
|
||||
raise PermissionError("No permission to create prompts")
|
||||
|
||||
# Create prompt record
|
||||
createdRecord = self.db.recordCreate("prompts", promptData)
|
||||
# Create prompt record
|
||||
createdRecord = self.db.recordCreate(Prompt, promptData)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create prompt record")
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("prompts")
|
||||
|
||||
return createdRecord
|
||||
|
||||
|
|
@ -306,10 +303,9 @@ class ComponentObjects:
|
|||
raise ValueError(f"Prompt {promptId} not found")
|
||||
|
||||
# Update prompt record directly with the update data
|
||||
self.db.recordModify("prompts", promptId, updateData)
|
||||
self.db.recordModify(Prompt, promptId, updateData)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("prompts")
|
||||
|
||||
# Get updated prompt
|
||||
updatedPrompt = self.getPrompt(promptId)
|
||||
|
|
@ -329,14 +325,12 @@ class ComponentObjects:
|
|||
if not prompt:
|
||||
return False
|
||||
|
||||
if not self._canModify("prompts", promptId):
|
||||
if not self._canModify(Prompt, promptId):
|
||||
raise PermissionError(f"No permission to delete prompt {promptId}")
|
||||
|
||||
# Delete prompt
|
||||
success = self.db.recordDelete("prompts", promptId)
|
||||
success = self.db.recordDelete(Prompt, promptId)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("prompts")
|
||||
|
||||
return success
|
||||
|
||||
|
|
@ -347,12 +341,12 @@ class ComponentObjects:
|
|||
If fileName is provided, also checks for exact name+hash match.
|
||||
Only returns files the current user has access to."""
|
||||
# First get all files with the hash
|
||||
allFilesWithHash = self.db.getRecordset("files", recordFilter={
|
||||
allFilesWithHash = self.db.getRecordset(FileItem, recordFilter={
|
||||
"fileHash": fileHash
|
||||
})
|
||||
|
||||
# Filter by user access using UAM
|
||||
accessibleFiles = self._uam("files", allFilesWithHash)
|
||||
accessibleFiles = self._uam(FileItem, allFilesWithHash)
|
||||
|
||||
if not accessibleFiles:
|
||||
return None
|
||||
|
|
@ -468,8 +462,8 @@ class ComponentObjects:
|
|||
|
||||
def getAllFiles(self) -> List[FileItem]:
|
||||
"""Returns files based on user access level."""
|
||||
allFiles = self.db.getRecordset("files")
|
||||
filteredFiles = self._uam("files", allFiles)
|
||||
allFiles = self.db.getRecordset(FileItem)
|
||||
filteredFiles = self._uam(FileItem, allFiles)
|
||||
|
||||
# Convert database records to FileItem instances
|
||||
fileItems = []
|
||||
|
|
@ -502,11 +496,11 @@ class ComponentObjects:
|
|||
|
||||
def getFile(self, fileId: str) -> Optional[FileItem]:
|
||||
"""Returns a file by ID if user has access."""
|
||||
files = self.db.getRecordset("files", recordFilter={"id": fileId})
|
||||
files = self.db.getRecordset(FileItem, recordFilter={"id": fileId})
|
||||
if not files:
|
||||
return None
|
||||
|
||||
filteredFiles = self._uam("files", files)
|
||||
filteredFiles = self._uam(FileItem, files)
|
||||
if not filteredFiles:
|
||||
return None
|
||||
|
||||
|
|
@ -534,7 +528,7 @@ class ComponentObjects:
|
|||
def _isfileNameUnique(self, fileName: str, excludeFileId: Optional[str] = None) -> bool:
|
||||
"""Checks if a fileName is unique for the current user."""
|
||||
# Get all files for current user
|
||||
files = self.db.getRecordset("files", recordFilter={
|
||||
files = self.db.getRecordset(FileItem, recordFilter={
|
||||
"_createdBy": self.currentUser.id
|
||||
})
|
||||
|
||||
|
|
@ -566,7 +560,7 @@ class ComponentObjects:
|
|||
def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem:
|
||||
"""Creates a new file entry if user has permission. Computes fileHash and fileSize from content."""
|
||||
import hashlib
|
||||
if not self._canModify("files"):
|
||||
if not self._canModify(FileItem):
|
||||
raise PermissionError("No permission to create files")
|
||||
|
||||
# Ensure fileName is unique
|
||||
|
|
@ -589,10 +583,8 @@ class ComponentObjects:
|
|||
)
|
||||
|
||||
# Store in database
|
||||
self.db.recordCreate("files", fileItem.to_dict())
|
||||
self.db.recordCreate(FileItem, fileItem)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("files")
|
||||
|
||||
return fileItem
|
||||
|
||||
|
|
@ -603,7 +595,7 @@ class ComponentObjects:
|
|||
if not file:
|
||||
raise FileNotFoundError(f"File with ID {fileId} not found")
|
||||
|
||||
if not self._canModify("files", fileId):
|
||||
if not self._canModify(FileItem, fileId):
|
||||
raise PermissionError(f"No permission to update file {fileId}")
|
||||
|
||||
# If fileName is being updated, ensure it's unique
|
||||
|
|
@ -611,10 +603,8 @@ class ComponentObjects:
|
|||
updateData["fileName"] = self._generateUniquefileName(updateData["fileName"], fileId)
|
||||
|
||||
# Update file
|
||||
success = self.db.recordModify("files", fileId, updateData)
|
||||
success = self.db.recordModify(FileItem, fileId, updateData)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("files")
|
||||
|
||||
return success
|
||||
|
||||
|
|
@ -627,30 +617,29 @@ class ComponentObjects:
|
|||
if not file:
|
||||
raise FileNotFoundError(f"File with ID {fileId} not found")
|
||||
|
||||
if not self._canModify("files", fileId):
|
||||
if not self._canModify(FileItem, fileId):
|
||||
raise PermissionError(f"No permission to delete file {fileId}")
|
||||
|
||||
# Check for other references to this file (by hash)
|
||||
fileHash = file.fileHash
|
||||
if fileHash:
|
||||
otherReferences = [f for f in self.db.getRecordset("files", recordFilter={"fileHash": fileHash})
|
||||
otherReferences = [f for f in self.db.getRecordset(FileItem, recordFilter={"fileHash": fileHash})
|
||||
if f["id"] != fileId]
|
||||
|
||||
# Only delete associated fileData if no other references exist
|
||||
if not otherReferences:
|
||||
try:
|
||||
fileDataEntries = self.db.getRecordset("fileData", recordFilter={"id": fileId})
|
||||
fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
|
||||
if fileDataEntries:
|
||||
self.db.recordDelete("fileData", fileId)
|
||||
self.db.recordDelete(FileData, fileId)
|
||||
logger.debug(f"FileData for file {fileId} deleted")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error deleting FileData for file {fileId}: {str(e)}")
|
||||
|
||||
# Delete the FileItem entry
|
||||
success = self.db.recordDelete("files", fileId)
|
||||
success = self.db.recordDelete(FileItem, fileId)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("files")
|
||||
|
||||
return success
|
||||
|
||||
|
|
@ -709,10 +698,9 @@ class ComponentObjects:
|
|||
"base64Encoded": base64Encoded
|
||||
}
|
||||
|
||||
self.db.recordCreate("fileData", fileDataObj)
|
||||
self.db.recordCreate(FileData, fileDataObj)
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
self._clearTableCache("fileData")
|
||||
|
||||
logger.debug(f"Successfully stored data for file {fileId} (base64Encoded: {base64Encoded})")
|
||||
return True
|
||||
|
|
@ -730,7 +718,7 @@ class ComponentObjects:
|
|||
|
||||
import base64
|
||||
|
||||
fileDataEntries = self.db.getRecordset("fileData", recordFilter={"id": fileId})
|
||||
fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
|
||||
if not fileDataEntries:
|
||||
logger.warning(f"No data found for file ID {fileId}")
|
||||
return None
|
||||
|
|
@ -830,7 +818,7 @@ class ComponentObjects:
|
|||
"""Saves an uploaded file if user has permission."""
|
||||
try:
|
||||
# Check file creation permission
|
||||
if not self._canModify("files"):
|
||||
if not self._canModify(FileItem):
|
||||
raise PermissionError("No permission to upload files")
|
||||
|
||||
logger.debug(f"Starting upload process for file: {fileName}")
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def get_token_status_for_connection(interface, connection_id: str) -> tuple[str,
|
|||
try:
|
||||
# Query tokens table for the latest token for this connection
|
||||
tokens = interface.db.getRecordset(
|
||||
table="tokens",
|
||||
Token,
|
||||
recordFilter={"connectionId": connection_id}
|
||||
)
|
||||
|
||||
|
|
@ -93,9 +93,6 @@ async def get_connections(
|
|||
try:
|
||||
interface = getInterface(currentUser)
|
||||
|
||||
# Clear connections cache to ensure fresh data
|
||||
interface.db.clearTableCache("connections")
|
||||
|
||||
# SECURITY FIX: All users (including admins) can only see their own connections
|
||||
# This prevents admin from seeing other users' connections and causing confusion
|
||||
connections = interface.getUserConnections(currentUser.id)
|
||||
|
|
@ -179,10 +176,8 @@ async def create_connection(
|
|||
)
|
||||
|
||||
# Save connection record - models now handle timestamp serialization automatically
|
||||
interface.db.recordModify("connections", connection.id, connection.to_dict())
|
||||
interface.db.recordModify(UserConnection, connection.id, connection.to_dict())
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
interface.db.clearTableCache("connections")
|
||||
|
||||
return connection
|
||||
|
||||
|
|
@ -235,10 +230,8 @@ async def update_connection(
|
|||
connection.lastChecked = get_utc_timestamp()
|
||||
|
||||
# Update connection - models now handle timestamp serialization automatically
|
||||
interface.db.recordModify("connections", connectionId, connection.to_dict())
|
||||
interface.db.recordModify(UserConnection, connectionId, connection.to_dict())
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
interface.db.clearTableCache("connections")
|
||||
|
||||
# Get token status for the updated connection
|
||||
token_status, token_expires_at = get_token_status_for_connection(interface, connectionId)
|
||||
|
|
@ -372,10 +365,8 @@ async def disconnect_service(
|
|||
connection.lastChecked = get_utc_timestamp()
|
||||
|
||||
# Update connection record - models now handle timestamp serialization automatically
|
||||
interface.db.recordModify("connections", connectionId, connection.to_dict())
|
||||
interface.db.recordModify(UserConnection, connectionId, connection.to_dict())
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
interface.db.clearTableCache("connections")
|
||||
|
||||
return {"message": "Service disconnected successfully"}
|
||||
|
||||
|
|
|
|||
|
|
@ -173,7 +173,8 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
rootInterface = getRootInterface()
|
||||
# Prefer connection flow reuse; fallback to user access token
|
||||
if connection_id:
|
||||
existing_tokens = rootInterface.db.getRecordset("tokens", recordFilter={
|
||||
from modules.interfaces.interfaceAppModel import Token
|
||||
existing_tokens = rootInterface.db.getRecordset(Token, recordFilter={
|
||||
"connectionId": connection_id,
|
||||
"authority": AuthAuthority.GOOGLE
|
||||
})
|
||||
|
|
@ -182,7 +183,7 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
existing_tokens.sort(key=lambda x: x.get("createdAt", 0), reverse=True)
|
||||
token_response["refresh_token"] = existing_tokens[0].get("tokenRefresh", "")
|
||||
if not token_response.get("refresh_token") and user_id:
|
||||
existing_access_tokens = rootInterface.db.getRecordset("tokens", recordFilter={
|
||||
existing_access_tokens = rootInterface.db.getRecordset(Token, recordFilter={
|
||||
"userId": user_id,
|
||||
"connectionId": None,
|
||||
"authority": AuthAuthority.GOOGLE
|
||||
|
|
@ -358,10 +359,9 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
connection.externalEmail = user_info.get("email")
|
||||
|
||||
# Update connection record directly
|
||||
rootInterface.db.recordModify("connections", connection_id, connection.to_dict())
|
||||
from modules.interfaces.interfaceAppModel import UserConnection
|
||||
rootInterface.db.recordModify(UserConnection, connection_id, connection.to_dict())
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
rootInterface.db.clearTableCache("connections")
|
||||
|
||||
# Save token
|
||||
token = Token(
|
||||
|
|
@ -543,7 +543,7 @@ async def refresh_token(
|
|||
google_connection.status = ConnectionStatus.ACTIVE
|
||||
|
||||
# Save updated connection
|
||||
appInterface.db.recordModify("connections", google_connection.id, google_connection.to_dict())
|
||||
appInterface.db.recordModify(UserConnection, google_connection.id, google_connection.to_dict())
|
||||
|
||||
# Calculate time until expiration
|
||||
current_time = get_utc_timestamp()
|
||||
|
|
|
|||
|
|
@ -52,7 +52,8 @@ async def login(
|
|||
rootInterface = getRootInterface()
|
||||
|
||||
# Get default mandate ID
|
||||
defaultMandateId = rootInterface.getInitialId("mandates")
|
||||
from modules.interfaces.interfaceAppModel import Mandate
|
||||
defaultMandateId = rootInterface.getInitialId(Mandate)
|
||||
if not defaultMandateId:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -146,7 +147,8 @@ async def register_user(
|
|||
appInterface = getRootInterface()
|
||||
|
||||
# Get default mandate ID
|
||||
defaultMandateId = appInterface.getInitialId("mandates")
|
||||
from modules.interfaces.interfaceAppModel import Mandate
|
||||
defaultMandateId = appInterface.getInitialId(Mandate)
|
||||
if not defaultMandateId:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
|
|||
|
|
@ -309,10 +309,8 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
connection.externalEmail = user_info.get("mail")
|
||||
|
||||
# Update connection record directly
|
||||
rootInterface.db.recordModify("connections", connection_id, connection.to_dict())
|
||||
rootInterface.db.recordModify(UserConnection, connection_id, connection.to_dict())
|
||||
|
||||
# Clear cache to ensure fresh data
|
||||
rootInterface.db.clearTableCache("connections")
|
||||
|
||||
# Save token
|
||||
|
||||
|
|
@ -524,7 +522,7 @@ async def refresh_token(
|
|||
msft_connection.status = ConnectionStatus.ACTIVE
|
||||
|
||||
# Save updated connection
|
||||
appInterface.db.recordModify("connections", msft_connection.id, msft_connection.to_dict())
|
||||
appInterface.db.recordModify(UserConnection, msft_connection.id, msft_connection.to_dict())
|
||||
|
||||
# Calculate time until expiration
|
||||
current_time = get_utc_timestamp()
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ async def get_workflows(
|
|||
"""Get all workflows for the current user."""
|
||||
try:
|
||||
appInterface = getInterface(currentUser)
|
||||
workflows_data = appInterface.getAllWorkflows()
|
||||
workflows_data = appInterface.getWorkflows()
|
||||
|
||||
# Convert raw dictionaries to ChatWorkflow objects
|
||||
workflows = []
|
||||
|
|
@ -136,7 +136,7 @@ async def update_workflow(
|
|||
workflowInterface = getInterface(currentUser)
|
||||
|
||||
# Get raw workflow data from database to check permissions
|
||||
workflows = workflowInterface.db.getRecordset("workflows", recordFilter={"id": workflowId})
|
||||
workflows = workflowInterface.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
if not workflows:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -225,7 +225,7 @@ async def get_workflow_logs(
|
|||
)
|
||||
|
||||
# Get all logs
|
||||
allLogs = interfaceChat.getWorkflowLogs(workflowId)
|
||||
allLogs = interfaceChat.getLogs(workflowId)
|
||||
|
||||
# Apply selective data transfer if logId is provided
|
||||
if logId:
|
||||
|
|
@ -268,7 +268,7 @@ async def get_workflow_messages(
|
|||
)
|
||||
|
||||
# Get all messages
|
||||
allMessages = interfaceChat.getWorkflowMessages(workflowId)
|
||||
allMessages = interfaceChat.getMessages(workflowId)
|
||||
|
||||
# Apply selective data transfer if messageId is provided
|
||||
if messageId:
|
||||
|
|
@ -356,7 +356,7 @@ async def delete_workflow(
|
|||
interfaceChat = getServiceChat(currentUser)
|
||||
|
||||
# Get raw workflow data from database to check permissions
|
||||
workflows = interfaceChat.db.getRecordset("workflows", recordFilter={"id": workflowId})
|
||||
workflows = interfaceChat.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
if not workflows:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -419,7 +419,7 @@ async def delete_workflow_message(
|
|||
)
|
||||
|
||||
# Delete the message
|
||||
success = interfaceChat.deleteWorkflowMessage(workflowId, messageId)
|
||||
success = interfaceChat.deleteMessage(workflowId, messageId)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -76,12 +76,12 @@ class WorkflowManager:
|
|||
"taskProgress": "pending",
|
||||
"actionProgress": "pending"
|
||||
}
|
||||
message = self.chatInterface.createWorkflowMessage(stopped_message)
|
||||
message = self.chatInterface.createMessage(stopped_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
# Add log entry
|
||||
self.chatInterface.createWorkflowLog({
|
||||
self.chatInterface.createLog({
|
||||
"workflowId": workflow.id,
|
||||
"message": "Workflow stopped by user",
|
||||
"type": "warning",
|
||||
|
|
@ -120,12 +120,12 @@ class WorkflowManager:
|
|||
"taskProgress": "fail",
|
||||
"actionProgress": "fail"
|
||||
}
|
||||
message = self.chatInterface.createWorkflowMessage(error_message)
|
||||
message = self.chatInterface.createMessage(error_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
# Add error log entry
|
||||
self.chatInterface.createWorkflowLog({
|
||||
self.chatInterface.createLog({
|
||||
"workflowId": workflow.id,
|
||||
"message": f"Workflow failed: {str(e)}",
|
||||
"type": "error",
|
||||
|
|
@ -165,16 +165,19 @@ class WorkflowManager:
|
|||
"actionProgress": "pending"
|
||||
}
|
||||
|
||||
# Add documents if any
|
||||
if userInput.listFileId:
|
||||
# Process file IDs and add to message data
|
||||
documents = await self.chatManager.service.processFileIds(userInput.listFileId)
|
||||
messageData["documents"] = documents
|
||||
|
||||
# Create message using interface
|
||||
message = self.chatInterface.createWorkflowMessage(messageData)
|
||||
# Create message first to get messageId
|
||||
message = self.chatInterface.createMessage(messageData)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
# Add documents if any, now with messageId
|
||||
if userInput.listFileId:
|
||||
# Process file IDs and add to message data
|
||||
documents = await self.chatManager.service.processFileIds(userInput.listFileId, message.id)
|
||||
message.documents = documents
|
||||
# Update the message with documents in database
|
||||
self.chatInterface.updateMessage(message.id, {"documents": [doc.to_dict() for doc in documents]})
|
||||
|
||||
return message
|
||||
else:
|
||||
raise Exception("Failed to create first message")
|
||||
|
|
@ -241,7 +244,7 @@ class WorkflowManager:
|
|||
}
|
||||
|
||||
# Create message using interface
|
||||
message = self.chatInterface.createWorkflowMessage(messageData)
|
||||
message = self.chatInterface.createMessage(messageData)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
|
|
@ -256,7 +259,7 @@ class WorkflowManager:
|
|||
})
|
||||
|
||||
# Add completion log entry
|
||||
self.chatInterface.createWorkflowLog({
|
||||
self.chatInterface.createLog({
|
||||
"workflowId": workflow.id,
|
||||
"message": "Workflow completed",
|
||||
"type": "success",
|
||||
|
|
@ -294,7 +297,7 @@ class WorkflowManager:
|
|||
"taskProgress": "stopped",
|
||||
"actionProgress": "stopped"
|
||||
}
|
||||
message = self.chatInterface.createWorkflowMessage(stopped_message)
|
||||
message = self.chatInterface.createMessage(stopped_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
|
|
@ -326,7 +329,7 @@ class WorkflowManager:
|
|||
"taskProgress": "stopped",
|
||||
"actionProgress": "stopped"
|
||||
}
|
||||
message = self.chatInterface.createWorkflowMessage(stopped_message)
|
||||
message = self.chatInterface.createMessage(stopped_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
|
|
@ -341,7 +344,7 @@ class WorkflowManager:
|
|||
})
|
||||
|
||||
# Add stopped log entry
|
||||
self.chatInterface.createWorkflowLog({
|
||||
self.chatInterface.createLog({
|
||||
"workflowId": workflow.id,
|
||||
"message": "Workflow stopped by user",
|
||||
"type": "warning",
|
||||
|
|
@ -368,7 +371,7 @@ class WorkflowManager:
|
|||
"taskProgress": "fail",
|
||||
"actionProgress": "fail"
|
||||
}
|
||||
message = self.chatInterface.createWorkflowMessage(error_message)
|
||||
message = self.chatInterface.createMessage(error_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
|
|
@ -383,7 +386,7 @@ class WorkflowManager:
|
|||
})
|
||||
|
||||
# Add failed log entry
|
||||
self.chatInterface.createWorkflowLog({
|
||||
self.chatInterface.createLog({
|
||||
"workflowId": workflow.id,
|
||||
"message": f"Workflow failed: {workflow_result.error or 'Unknown error'}",
|
||||
"type": "error",
|
||||
|
|
@ -411,7 +414,7 @@ class WorkflowManager:
|
|||
"actionProgress": "success"
|
||||
}
|
||||
|
||||
message = self.chatInterface.createWorkflowMessage(summary_message)
|
||||
message = self.chatInterface.createMessage(summary_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
|
|
@ -426,7 +429,7 @@ class WorkflowManager:
|
|||
})
|
||||
|
||||
# Add completion log entry
|
||||
self.chatInterface.createWorkflowLog({
|
||||
self.chatInterface.createLog({
|
||||
"workflowId": workflow.id,
|
||||
"message": "Workflow completed successfully",
|
||||
"type": "success",
|
||||
|
|
@ -454,7 +457,7 @@ class WorkflowManager:
|
|||
"taskProgress": "fail",
|
||||
"actionProgress": "fail"
|
||||
}
|
||||
message = self.chatInterface.createWorkflowMessage(error_message)
|
||||
message = self.chatInterface.createMessage(error_message)
|
||||
if message:
|
||||
workflow.messages.append(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ TODO
|
|||
# System
|
||||
- database
|
||||
- db initialization as separate function to create root mandate, then sysadmin with hashed passwords --> using the connector according to env configuration
|
||||
- config page for: db reset
|
||||
- settings: UI page for: db new (delete if exists and init), then to add mandate root and sysadmin, log download --> in the api to add connector settings with the according endpoints
|
||||
- access model as matrix, not as code --> to have view, add, update, delete with the rights on level table and attribute for all, my (created by me), my mandate (mandate I am in), none (no access)
|
||||
- document handling centralized
|
||||
- ai handling centralized
|
||||
- neutralizer to activate AND put back placeholders to the returned data
|
||||
|
|
|
|||
|
|
@ -1,8 +0,0 @@
|
|||
New features
|
||||
- Limiter and tracking of ip adress access
|
||||
- Sessions improved
|
||||
- user and connection consequently separated
|
||||
- seamless local and external authorities integration
|
||||
- audit trail
|
||||
- nda disclaimer in login window
|
||||
- CSRF Tokens included in forms
|
||||
1
query
Normal file
1
query
Normal file
|
|
@ -0,0 +1 @@
|
|||
postgresql
|
||||
|
|
@ -90,4 +90,7 @@ bokeh>=3.2.0,<3.4.0
|
|||
linkify-it-py>=1.0.0
|
||||
mdit-py-plugins>=0.3.0
|
||||
pyviz-comms>=2.0.0
|
||||
xyzservices>=2021.09.1
|
||||
xyzservices>=2021.09.1
|
||||
|
||||
# PostgreSQL connector dependencies
|
||||
psycopg2-binary==2.9.9
|
||||
|
|
|
|||
|
|
@ -1,237 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify concurrency improvements in DatabaseConnector.
|
||||
This script simulates multiple users accessing the database simultaneously.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
# Add the gateway directory to the path
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.connectors.connectorPool import get_connector, return_connector
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_concurrent_record_operations():
|
||||
"""Test concurrent record creation, modification, and deletion."""
|
||||
|
||||
# Create temporary database directory
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_host = temp_dir
|
||||
db_database = "test_db"
|
||||
|
||||
try:
|
||||
logger.info("Starting concurrency test...")
|
||||
|
||||
def user_operation(user_id: int, operation_count: int = 10):
|
||||
"""Simulate a user performing database operations."""
|
||||
try:
|
||||
# Get a dedicated connector for this user
|
||||
db = get_connector(
|
||||
dbHost=db_host,
|
||||
dbDatabase=db_database,
|
||||
userId=f"user_{user_id}"
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
for i in range(operation_count):
|
||||
# Create a record
|
||||
record = {
|
||||
"id": f"user_{user_id}_record_{i}",
|
||||
"data": f"User {user_id} data {i}",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
# Create record
|
||||
created = db.recordCreate("test_table", record)
|
||||
results.append(f"Created: {created['id']}")
|
||||
|
||||
# Modify record
|
||||
record["data"] = f"Modified by user {user_id} - {i}"
|
||||
modified = db.recordModify("test_table", record["id"], record)
|
||||
results.append(f"Modified: {modified['id']}")
|
||||
|
||||
# Small delay to increase chance of race conditions
|
||||
time.sleep(0.001)
|
||||
|
||||
# Return connector to pool
|
||||
return_connector(db)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"User {user_id} error: {e}")
|
||||
return [f"Error: {e}"]
|
||||
|
||||
# Test with multiple concurrent users
|
||||
num_users = 20
|
||||
operations_per_user = 5
|
||||
|
||||
logger.info(f"Testing with {num_users} users, {operations_per_user} operations each")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_users) as executor:
|
||||
# Submit all user operations
|
||||
futures = [
|
||||
executor.submit(user_operation, user_id, operations_per_user)
|
||||
for user_id in range(num_users)
|
||||
]
|
||||
|
||||
# Collect results
|
||||
all_results = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
result = future.result()
|
||||
all_results.extend(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Future error: {e}")
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Verify data integrity
|
||||
db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier")
|
||||
|
||||
# Check that all records exist and are consistent
|
||||
all_records = db.getRecordset("test_table")
|
||||
expected_count = num_users * operations_per_user
|
||||
|
||||
logger.info(f"Expected records: {expected_count}")
|
||||
logger.info(f"Actual records: {len(all_records)}")
|
||||
logger.info(f"Test completed in {end_time - start_time:.2f} seconds")
|
||||
|
||||
# Check for data consistency
|
||||
record_ids = set(record["id"] for record in all_records)
|
||||
expected_ids = set(f"user_{user_id}_record_{i}" for user_id in range(num_users) for i in range(operations_per_user))
|
||||
|
||||
missing_ids = expected_ids - record_ids
|
||||
extra_ids = record_ids - expected_ids
|
||||
|
||||
if missing_ids:
|
||||
logger.error(f"Missing records: {missing_ids}")
|
||||
if extra_ids:
|
||||
logger.error(f"Extra records: {extra_ids}")
|
||||
|
||||
# Check for data corruption (records with wrong user data)
|
||||
corrupted_records = []
|
||||
for record in all_records:
|
||||
record_id = record["id"]
|
||||
user_id = int(record_id.split("_")[1])
|
||||
if f"Modified by user {user_id}" not in record.get("data", ""):
|
||||
corrupted_records.append(record_id)
|
||||
|
||||
if corrupted_records:
|
||||
logger.error(f"Corrupted records: {corrupted_records}")
|
||||
|
||||
success = len(missing_ids) == 0 and len(extra_ids) == 0 and len(corrupted_records) == 0
|
||||
|
||||
if success:
|
||||
logger.info("✅ Concurrency test PASSED - No data corruption detected")
|
||||
else:
|
||||
logger.error("❌ Concurrency test FAILED - Data corruption detected")
|
||||
|
||||
return success
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
logger.info("Cleaned up temporary directory")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up: {e}")
|
||||
|
||||
def test_metadata_consistency():
|
||||
"""Test that metadata operations are atomic."""
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_host = temp_dir
|
||||
db_database = "test_metadata"
|
||||
|
||||
try:
|
||||
logger.info("Testing metadata consistency...")
|
||||
|
||||
def concurrent_metadata_operations(user_id: int):
|
||||
"""Perform concurrent metadata operations."""
|
||||
db = get_connector(
|
||||
dbHost=db_host,
|
||||
dbDatabase=db_database,
|
||||
userId=f"user_{user_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Create multiple records rapidly
|
||||
for i in range(10):
|
||||
record = {
|
||||
"id": f"user_{user_id}_meta_{i}",
|
||||
"data": f"Metadata test {user_id}-{i}"
|
||||
}
|
||||
db.recordCreate("metadata_test", record)
|
||||
time.sleep(0.001) # Small delay
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Metadata test error for user {user_id}: {e}")
|
||||
return False
|
||||
finally:
|
||||
return_connector(db)
|
||||
|
||||
# Run concurrent metadata operations
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(concurrent_metadata_operations, i) for i in range(10)]
|
||||
results = [future.result() for future in as_completed(futures)]
|
||||
|
||||
# Verify metadata consistency
|
||||
db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier")
|
||||
records = db.getRecordset("metadata_test")
|
||||
|
||||
# Check that metadata is consistent
|
||||
metadata = db._loadTableMetadata("metadata_test")
|
||||
expected_count = len(records)
|
||||
actual_count = len(metadata["recordIds"])
|
||||
|
||||
logger.info(f"Expected record count: {expected_count}")
|
||||
logger.info(f"Metadata record count: {actual_count}")
|
||||
|
||||
success = expected_count == actual_count
|
||||
|
||||
if success:
|
||||
logger.info("✅ Metadata consistency test PASSED")
|
||||
else:
|
||||
logger.error("❌ Metadata consistency test FAILED")
|
||||
|
||||
return success
|
||||
|
||||
finally:
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting concurrency tests...")
|
||||
|
||||
# Test 1: Concurrent record operations
|
||||
test1_passed = test_concurrent_record_operations()
|
||||
|
||||
# Test 2: Metadata consistency
|
||||
test2_passed = test_metadata_consistency()
|
||||
|
||||
# Overall result
|
||||
if test1_passed and test2_passed:
|
||||
logger.info("🎉 All concurrency tests PASSED!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("💥 Some concurrency tests FAILED!")
|
||||
sys.exit(1)
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
"""Tests for Tavliy web search."""
|
||||
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
from modules.interfaces.interfaceChatModel import ActionResult
|
||||
from gateway.modules.interfaces.interfaceWebModel import (
|
||||
WebSearchRequest,
|
||||
WebCrawlRequest,
|
||||
WebScrapeRequest,
|
||||
)
|
||||
from gateway.modules.connectors.connectorWebTavily import ConnectorTavily
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.expensive
|
||||
async def test_tavily_connector_search_test_live_api():
|
||||
logger.info("Testing Tavliy connector search with live API calls")
|
||||
|
||||
# Test request
|
||||
request = WebSearchRequest(query="How old is the Earth?", max_results=5)
|
||||
|
||||
# Tavily instance
|
||||
connectorWebTavily = await ConnectorTavily.create()
|
||||
|
||||
# Search test
|
||||
action_result = await connectorWebTavily.search_urls(request=request)
|
||||
|
||||
# Check results
|
||||
assert isinstance(action_result, ActionResult)
|
||||
|
||||
logger.info("=" * 20)
|
||||
logger.info(f"Action result success status: {action_result.success}")
|
||||
logger.info(f"Action result error: {action_result.error}")
|
||||
logger.info(f"Action result label: {action_result.resultLabel}")
|
||||
|
||||
logger.info("Documents:")
|
||||
for doc in action_result.documents:
|
||||
logger.info("-" * 10)
|
||||
logger.info(f" - Document Name: {doc.documentName}")
|
||||
logger.info(f" - Document Mime Type: {doc.mimeType}")
|
||||
logger.info(f" - Document Data: {doc.documentData}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.expensive
|
||||
async def test_tavily_connector_crawl_test_live_api():
|
||||
logger.info("Testing Tavily connector crawl with live API calls")
|
||||
|
||||
# Test request
|
||||
urls = [
|
||||
"https://en.wikipedia.org/wiki/Earth",
|
||||
"https://valueon.ch",
|
||||
]
|
||||
request = WebCrawlRequest(urls=urls)
|
||||
|
||||
# Tavily instance
|
||||
connectorWebTavily = await ConnectorTavily.create()
|
||||
|
||||
# Crawl test
|
||||
action_result = await connectorWebTavily.crawl_urls(request=request)
|
||||
|
||||
# Check results
|
||||
assert isinstance(action_result, ActionResult)
|
||||
|
||||
logger.info("=" * 20)
|
||||
logger.info(f"Action result success status: {action_result.success}")
|
||||
logger.info(f"Action result error: {action_result.error}")
|
||||
logger.info(f"Action result label: {action_result.resultLabel}")
|
||||
|
||||
logger.info("Documents:")
|
||||
for doc in action_result.documents:
|
||||
logger.info("-" * 10)
|
||||
logger.info(f" - Document Name: {doc.documentName}")
|
||||
logger.info(f" - Document Mime Type: {doc.mimeType}")
|
||||
logger.info(f" - Document Data: {doc.documentData}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.expensive
|
||||
async def test_tavily_connector_scrape_test_live_api():
|
||||
logger.info("Testing Tavily connector scrape with live API calls")
|
||||
|
||||
# Test request with query
|
||||
request = WebScrapeRequest(query="How old is the Earth?", max_results=3)
|
||||
|
||||
# Tavily instance
|
||||
connectorWebTavily = await ConnectorTavily.create()
|
||||
|
||||
# Scrape test
|
||||
action_result = await connectorWebTavily.scrape(request=request)
|
||||
|
||||
# Check results
|
||||
assert isinstance(action_result, ActionResult)
|
||||
|
||||
logger.info("=" * 20)
|
||||
logger.info(f"Action result success status: {action_result.success}")
|
||||
logger.info(f"Action result error: {action_result.error}")
|
||||
logger.info(f"Action result label: {action_result.resultLabel}")
|
||||
|
||||
logger.info("Documents:")
|
||||
for doc in action_result.documents:
|
||||
logger.info("-" * 10)
|
||||
logger.info(f" - Document Name: {doc.documentName}")
|
||||
logger.info(f" - Document Mime Type: {doc.mimeType}")
|
||||
logger.info(f" - Document Data: {doc.documentData}")
|
||||
Loading…
Reference in a new issue