database user session separation
This commit is contained in:
parent
1ff4248346
commit
98a4323b36
7 changed files with 555 additions and 71 deletions
|
|
@ -131,6 +131,11 @@ class DatabaseConnector:
|
|||
|
||||
return lock
|
||||
|
||||
def _get_table_lock(self, table: str, timeout_seconds: int = 30):
|
||||
"""Get table-level lock for metadata operations"""
|
||||
table_lock_key = f"table_{table}"
|
||||
return self._get_file_lock(table_lock_key, timeout_seconds)
|
||||
|
||||
def _ensureTableDirectory(self, table: str) -> bool:
|
||||
"""Ensures the table directory exists."""
|
||||
if table == self._systemTableName:
|
||||
|
|
@ -145,7 +150,9 @@ class DatabaseConnector:
|
|||
return False
|
||||
|
||||
def _loadTableMetadata(self, table: str) -> Dict[str, Any]:
|
||||
"""Loads table metadata (list of record IDs) without loading actual records."""
|
||||
"""Loads table metadata (list of record IDs) without loading actual records.
|
||||
NOTE: This method is safe to call without additional locking.
|
||||
"""
|
||||
if table in self._tableMetadataCache:
|
||||
return self._tableMetadataCache[table]
|
||||
|
||||
|
|
@ -159,7 +166,7 @@ class DatabaseConnector:
|
|||
try:
|
||||
if os.path.exists(tablePath):
|
||||
for fileName in os.listdir(tablePath):
|
||||
if fileName.endswith('.json'):
|
||||
if fileName.endswith('.json') and fileName != '_metadata.json':
|
||||
recordId = fileName[:-5] # Remove .json extension
|
||||
metadata["recordIds"].append(recordId)
|
||||
|
||||
|
|
@ -183,17 +190,23 @@ class DatabaseConnector:
|
|||
return None
|
||||
|
||||
def _saveRecord(self, table: str, recordId: str, record: Dict[str, Any]) -> bool:
|
||||
"""Saves a single record to the table."""
|
||||
"""Saves a single record to the table with atomic metadata operations."""
|
||||
recordPath = self._getRecordPath(table, recordId)
|
||||
lock = self._get_file_lock(recordPath)
|
||||
record_lock = self._get_file_lock(recordPath)
|
||||
table_lock = self._get_table_lock(table)
|
||||
|
||||
try:
|
||||
# Acquire lock with timeout
|
||||
if not lock.acquire(timeout=30): # 30 second timeout
|
||||
raise TimeoutError(f"Could not acquire lock for {recordPath} within 30 seconds")
|
||||
# Acquire both locks with timeout - record lock first, then table lock
|
||||
if not record_lock.acquire(timeout=30):
|
||||
raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds")
|
||||
|
||||
if not table_lock.acquire(timeout=30):
|
||||
record_lock.release()
|
||||
raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds")
|
||||
|
||||
# Record lock acquisition time
|
||||
self._lock_timeouts[recordPath] = time.time()
|
||||
self._lock_timeouts[f"table_{table}"] = time.time()
|
||||
|
||||
# Ensure table directory exists
|
||||
if not self._ensureTableDirectory(table):
|
||||
|
|
@ -239,14 +252,14 @@ class DatabaseConnector:
|
|||
# Atomic move from temp to final location
|
||||
os.replace(tempPath, recordPath)
|
||||
|
||||
# Update metadata
|
||||
# ATOMIC: Update metadata while holding both locks
|
||||
metadata = self._loadTableMetadata(table)
|
||||
if recordId not in metadata["recordIds"]:
|
||||
metadata["recordIds"].append(recordId)
|
||||
metadata["recordIds"].sort()
|
||||
self._saveTableMetadata(table, metadata)
|
||||
|
||||
# Update cache if it exists
|
||||
# Update cache if it exists (also protected by table lock)
|
||||
if table in self._tablesCache:
|
||||
# Find and update existing record or append new one
|
||||
found = False
|
||||
|
|
@ -272,14 +285,22 @@ class DatabaseConnector:
|
|||
return False
|
||||
|
||||
finally:
|
||||
# ALWAYS release lock, even on error
|
||||
# ALWAYS release both locks, even on error
|
||||
try:
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
if table_lock.locked():
|
||||
table_lock.release()
|
||||
if f"table_{table}" in self._lock_timeouts:
|
||||
del self._lock_timeouts[f"table_{table}"]
|
||||
except Exception as release_error:
|
||||
logger.error(f"Error releasing table lock for {table}: {release_error}")
|
||||
|
||||
try:
|
||||
if record_lock.locked():
|
||||
record_lock.release()
|
||||
if recordPath in self._lock_timeouts:
|
||||
del self._lock_timeouts[recordPath]
|
||||
except Exception as release_error:
|
||||
logger.error(f"Error releasing lock for {recordPath}: {release_error}")
|
||||
logger.error(f"Error releasing record lock for {recordPath}: {release_error}")
|
||||
|
||||
def _loadTable(self, table: str) -> List[Dict[str, Any]]:
|
||||
"""Loads all records from a table folder."""
|
||||
|
|
@ -403,23 +424,14 @@ class DatabaseConnector:
|
|||
|
||||
|
||||
def _saveTableMetadata(self, table: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""Saves table metadata to a metadata file."""
|
||||
"""Saves table metadata to a metadata file.
|
||||
NOTE: This method assumes the caller already holds the table lock.
|
||||
"""
|
||||
try:
|
||||
# Create metadata file path
|
||||
metadataPath = os.path.join(self._getTablePath(table), "_metadata.json")
|
||||
|
||||
# Get lock for metadata file
|
||||
lock = self._get_file_lock(metadataPath)
|
||||
|
||||
try:
|
||||
# Acquire lock with timeout
|
||||
if not lock.acquire(timeout=30):
|
||||
raise TimeoutError(f"Could not acquire lock for metadata {metadataPath} within 30 seconds")
|
||||
|
||||
# Record lock acquisition time
|
||||
self._lock_timeouts[metadataPath] = time.time()
|
||||
|
||||
# Save metadata
|
||||
# Save metadata (caller should already hold table lock)
|
||||
with open(metadataPath, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
|
@ -428,16 +440,6 @@ class DatabaseConnector:
|
|||
|
||||
return True
|
||||
|
||||
finally:
|
||||
# ALWAYS release lock
|
||||
try:
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
if metadataPath in self._lock_timeouts:
|
||||
del self._lock_timeouts[metadataPath]
|
||||
except Exception as release_error:
|
||||
logger.error(f"Error releasing metadata lock for {metadataPath}: {release_error}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving metadata for table {table}: {e}")
|
||||
return False
|
||||
|
|
@ -582,7 +584,24 @@ class DatabaseConnector:
|
|||
return existingRecord
|
||||
|
||||
def recordDelete(self, table: str, recordId: str) -> bool:
|
||||
"""Deletes a record from the table."""
|
||||
"""Deletes a record from the table with atomic metadata operations."""
|
||||
recordPath = self._getRecordPath(table, recordId)
|
||||
record_lock = self._get_file_lock(recordPath)
|
||||
table_lock = self._get_table_lock(table)
|
||||
|
||||
try:
|
||||
# Acquire both locks with timeout - record lock first, then table lock
|
||||
if not record_lock.acquire(timeout=30):
|
||||
raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds")
|
||||
|
||||
if not table_lock.acquire(timeout=30):
|
||||
record_lock.release()
|
||||
raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds")
|
||||
|
||||
# Record lock acquisition time
|
||||
self._lock_timeouts[recordPath] = time.time()
|
||||
self._lock_timeouts[f"table_{table}"] = time.time()
|
||||
|
||||
# Load metadata
|
||||
metadata = self._loadTableMetadata(table)
|
||||
|
||||
|
|
@ -596,26 +615,43 @@ class DatabaseConnector:
|
|||
logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table")
|
||||
|
||||
# Delete the record file
|
||||
recordPath = self._getRecordPath(table, recordId)
|
||||
try:
|
||||
if os.path.exists(recordPath):
|
||||
os.remove(recordPath)
|
||||
|
||||
# Update metadata cache
|
||||
# ATOMIC: Update metadata while holding both locks
|
||||
metadata["recordIds"].remove(recordId)
|
||||
self._tableMetadataCache[table] = metadata
|
||||
self._saveTableMetadata(table, metadata)
|
||||
|
||||
# Update table cache if it exists
|
||||
# Update table cache if it exists (also protected by table lock)
|
||||
if table in self._tablesCache:
|
||||
self._tablesCache[table] = [r for r in self._tablesCache[table] if r.get("id") != recordId]
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting record file {recordPath}: {e}")
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
# ALWAYS release both locks, even on error
|
||||
try:
|
||||
if table_lock.locked():
|
||||
table_lock.release()
|
||||
if f"table_{table}" in self._lock_timeouts:
|
||||
del self._lock_timeouts[f"table_{table}"]
|
||||
except Exception as release_error:
|
||||
logger.error(f"Error releasing table lock for {table}: {release_error}")
|
||||
|
||||
try:
|
||||
if record_lock.locked():
|
||||
record_lock.release()
|
||||
if recordPath in self._lock_timeouts:
|
||||
del self._lock_timeouts[recordPath]
|
||||
except Exception as release_error:
|
||||
logger.error(f"Error releasing record lock for {recordPath}: {release_error}")
|
||||
|
||||
def getInitialId(self, table: str) -> Optional[str]:
|
||||
"""Returns the initial ID for a table."""
|
||||
systemData = self._loadSystemTable()
|
||||
|
|
|
|||
178
modules/connectors/connectorPool.py
Normal file
178
modules/connectors/connectorPool.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
import threading
|
||||
import queue
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from .connectorDbJson import DatabaseConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DatabaseConnectorPool:
|
||||
"""
|
||||
A connection pool for DatabaseConnector instances to manage resources efficiently
|
||||
and ensure proper isolation between users.
|
||||
"""
|
||||
|
||||
def __init__(self, max_connections: int = 100, max_idle_time: int = 300):
|
||||
"""
|
||||
Initialize the connection pool.
|
||||
|
||||
Args:
|
||||
max_connections: Maximum number of connections in the pool
|
||||
max_idle_time: Maximum idle time in seconds before connection is considered stale
|
||||
"""
|
||||
self.max_connections = max_connections
|
||||
self.max_idle_time = max_idle_time
|
||||
self._pool = queue.Queue(maxsize=max_connections)
|
||||
self._created_connections = 0
|
||||
self._lock = threading.Lock()
|
||||
self._connection_times = {} # Track when connections were created
|
||||
|
||||
def _create_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None,
|
||||
dbPassword: str = None, userId: str = None) -> DatabaseConnector:
|
||||
"""Create a new DatabaseConnector instance."""
|
||||
with self._lock:
|
||||
if self._created_connections >= self.max_connections:
|
||||
raise RuntimeError(f"Maximum connections ({self.max_connections}) exceeded")
|
||||
|
||||
self._created_connections += 1
|
||||
logger.debug(f"Creating new database connector (total: {self._created_connections})")
|
||||
|
||||
connector = DatabaseConnector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
userId=userId
|
||||
)
|
||||
|
||||
# Track creation time
|
||||
connector_id = id(connector)
|
||||
self._connection_times[connector_id] = time.time()
|
||||
|
||||
return connector
|
||||
|
||||
def get_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None,
|
||||
dbPassword: str = None, userId: str = None) -> DatabaseConnector:
|
||||
"""
|
||||
Get a database connector from the pool or create a new one.
|
||||
|
||||
Args:
|
||||
dbHost: Database host path
|
||||
dbDatabase: Database name
|
||||
dbUser: Database user (optional)
|
||||
dbPassword: Database password (optional)
|
||||
userId: User ID for context (optional)
|
||||
|
||||
Returns:
|
||||
DatabaseConnector instance
|
||||
"""
|
||||
try:
|
||||
# Try to get an existing connector from the pool
|
||||
connector = self._pool.get_nowait()
|
||||
|
||||
# Check if connector is stale
|
||||
connector_id = id(connector)
|
||||
if connector_id in self._connection_times:
|
||||
idle_time = time.time() - self._connection_times[connector_id]
|
||||
if idle_time > self.max_idle_time:
|
||||
logger.debug(f"Connector {connector_id} is stale (idle: {idle_time}s), creating new one")
|
||||
# Remove stale connector from tracking
|
||||
if connector_id in self._connection_times:
|
||||
del self._connection_times[connector_id]
|
||||
# Create new connector
|
||||
return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)
|
||||
|
||||
# Update user context if provided
|
||||
if userId is not None:
|
||||
connector.updateContext(userId)
|
||||
|
||||
logger.debug(f"Reusing existing connector {connector_id}")
|
||||
return connector
|
||||
|
||||
except queue.Empty:
|
||||
# Pool is empty, create new connector
|
||||
return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)
|
||||
|
||||
def return_connector(self, connector: DatabaseConnector) -> None:
|
||||
"""
|
||||
Return a connector to the pool for reuse.
|
||||
|
||||
Args:
|
||||
connector: DatabaseConnector instance to return
|
||||
"""
|
||||
try:
|
||||
# Update connection time
|
||||
connector_id = id(connector)
|
||||
self._connection_times[connector_id] = time.time()
|
||||
|
||||
# Try to return to pool
|
||||
self._pool.put_nowait(connector)
|
||||
logger.debug(f"Returned connector {connector_id} to pool")
|
||||
|
||||
except queue.Full:
|
||||
# Pool is full, discard connector
|
||||
logger.debug(f"Pool full, discarding connector {id(connector)}")
|
||||
with self._lock:
|
||||
self._created_connections -= 1
|
||||
if id(connector) in self._connection_times:
|
||||
del self._connection_times[id(connector)]
|
||||
|
||||
def cleanup_stale_connections(self) -> int:
|
||||
"""
|
||||
Clean up stale connections from the pool.
|
||||
|
||||
Returns:
|
||||
Number of connections cleaned up
|
||||
"""
|
||||
cleaned = 0
|
||||
current_time = time.time()
|
||||
|
||||
# Check all tracked connections
|
||||
stale_connectors = []
|
||||
for connector_id, creation_time in list(self._connection_times.items()):
|
||||
if current_time - creation_time > self.max_idle_time:
|
||||
stale_connectors.append(connector_id)
|
||||
|
||||
# Remove stale connections from tracking
|
||||
for connector_id in stale_connectors:
|
||||
if connector_id in self._connection_times:
|
||||
del self._connection_times[connector_id]
|
||||
cleaned += 1
|
||||
|
||||
logger.debug(f"Cleaned up {cleaned} stale connections")
|
||||
return cleaned
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get pool statistics."""
|
||||
with self._lock:
|
||||
return {
|
||||
"max_connections": self.max_connections,
|
||||
"created_connections": self._created_connections,
|
||||
"available_connections": self._pool.qsize(),
|
||||
"tracked_connections": len(self._connection_times)
|
||||
}
|
||||
|
||||
# Global pool instance
|
||||
_connector_pool = None
|
||||
_pool_lock = threading.Lock()
|
||||
|
||||
def get_connector_pool() -> DatabaseConnectorPool:
|
||||
"""Get the global connector pool instance."""
|
||||
global _connector_pool
|
||||
if _connector_pool is None:
|
||||
with _pool_lock:
|
||||
if _connector_pool is None:
|
||||
_connector_pool = DatabaseConnectorPool()
|
||||
return _connector_pool
|
||||
|
||||
def get_connector(dbHost: str, dbDatabase: str, dbUser: str = None,
|
||||
dbPassword: str = None, userId: str = None) -> DatabaseConnector:
|
||||
"""Get a database connector from the global pool."""
|
||||
pool = get_connector_pool()
|
||||
return pool.get_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)
|
||||
|
||||
def return_connector(connector: DatabaseConnector) -> None:
|
||||
"""Return a database connector to the global pool."""
|
||||
pool = get_connector_pool()
|
||||
pool.return_connector(connector)
|
||||
|
|
@ -13,6 +13,7 @@ from passlib.context import CryptContext
|
|||
import uuid
|
||||
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.connectors.connectorPool import get_connector, return_connector
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp
|
||||
from modules.interfaces.interfaceAppAccess import AppAccess
|
||||
|
|
@ -79,8 +80,16 @@ class AppObjects:
|
|||
# Update database context
|
||||
self.db.updateContext(self.userId)
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup method to return connector to pool."""
|
||||
if hasattr(self, 'db') and self.db is not None:
|
||||
try:
|
||||
return_connector(self.db)
|
||||
except Exception as e:
|
||||
logger.error(f"Error returning connector to pool: {e}")
|
||||
|
||||
def _initializeDatabase(self):
|
||||
"""Initializes the database connection."""
|
||||
"""Initializes the database connection using connection pool."""
|
||||
try:
|
||||
# Get configuration values with defaults
|
||||
dbHost = APP_CONFIG.get("DB_APP_HOST", "_no_config_default_data")
|
||||
|
|
@ -91,14 +100,16 @@ class AppObjects:
|
|||
# Ensure the database directory exists
|
||||
os.makedirs(dbHost, exist_ok=True)
|
||||
|
||||
self.db = DatabaseConnector(
|
||||
# Get connector from pool with user context
|
||||
self.db = get_connector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword
|
||||
dbPassword=dbPassword,
|
||||
userId=self.userId
|
||||
)
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
logger.info(f"Database initialized successfully for user {self.userId}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from modules.interfaces.interfaceAppModel import User
|
|||
|
||||
# DYNAMIC PART: Connectors to the Interface
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.connectors.connectorPool import get_connector, return_connector
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
|
||||
# Basic Configurations
|
||||
|
|
@ -73,6 +74,14 @@ class ChatObjects:
|
|||
# Update database context
|
||||
self.db.updateContext(self.userId)
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup method to return connector to pool."""
|
||||
if hasattr(self, 'db') and self.db is not None:
|
||||
try:
|
||||
return_connector(self.db)
|
||||
except Exception as e:
|
||||
logger.error(f"Error returning connector to pool: {e}")
|
||||
|
||||
logger.debug(f"User context set: userId={self.userId}, mandateId={self.mandateId}")
|
||||
|
||||
def _initializeDatabase(self):
|
||||
|
|
@ -87,11 +96,12 @@ class ChatObjects:
|
|||
# Ensure the database directory exists
|
||||
os.makedirs(dbHost, exist_ok=True)
|
||||
|
||||
self.db = DatabaseConnector(
|
||||
self.db = get_connector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword
|
||||
dbPassword=dbPassword,
|
||||
userId=self.userId
|
||||
)
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from modules.interfaces.interfaceAppModel import User
|
|||
|
||||
# DYNAMIC PART: Connectors to the Interface
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.connectors.connectorPool import get_connector, return_connector
|
||||
|
||||
# Basic Configurations
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
|
@ -88,6 +89,14 @@ class ComponentObjects:
|
|||
# Update database context
|
||||
self.db.updateContext(self.userId)
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup method to return connector to pool."""
|
||||
if hasattr(self, 'db') and self.db is not None:
|
||||
try:
|
||||
return_connector(self.db)
|
||||
except Exception as e:
|
||||
logger.error(f"Error returning connector to pool: {e}")
|
||||
|
||||
logger.debug(f"User context set: userId={self.userId}")
|
||||
|
||||
def _initializeDatabase(self):
|
||||
|
|
@ -102,11 +111,12 @@ class ComponentObjects:
|
|||
# Ensure the database directory exists
|
||||
os.makedirs(dbHost, exist_ok=True)
|
||||
|
||||
self.db = DatabaseConnector(
|
||||
self.db = get_connector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword
|
||||
dbPassword=dbPassword,
|
||||
userId=self.userId if hasattr(self, 'userId') else None
|
||||
)
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@
|
|||
TODO
|
||||
|
||||
# System
|
||||
- sharepoint to fix
|
||||
- database
|
||||
- db initialization as separate function to create root mandate, then sysadmin with hashed passwords --> using the connector according to env configuration
|
||||
- config page for: db reset
|
||||
- document handling centralized
|
||||
- ai handling centralized
|
||||
- neutralizer to activate AND put back placeholders to the returned data
|
||||
|
|
|
|||
237
test_concurrency_fixes.py
Normal file
237
test_concurrency_fixes.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify concurrency improvements in DatabaseConnector.
|
||||
This script simulates multiple users accessing the database simultaneously.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
# Add the gateway directory to the path
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from modules.connectors.connectorDbJson import DatabaseConnector
|
||||
from modules.connectors.connectorPool import get_connector, return_connector
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_concurrent_record_operations():
|
||||
"""Test concurrent record creation, modification, and deletion."""
|
||||
|
||||
# Create temporary database directory
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_host = temp_dir
|
||||
db_database = "test_db"
|
||||
|
||||
try:
|
||||
logger.info("Starting concurrency test...")
|
||||
|
||||
def user_operation(user_id: int, operation_count: int = 10):
|
||||
"""Simulate a user performing database operations."""
|
||||
try:
|
||||
# Get a dedicated connector for this user
|
||||
db = get_connector(
|
||||
dbHost=db_host,
|
||||
dbDatabase=db_database,
|
||||
userId=f"user_{user_id}"
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
for i in range(operation_count):
|
||||
# Create a record
|
||||
record = {
|
||||
"id": f"user_{user_id}_record_{i}",
|
||||
"data": f"User {user_id} data {i}",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
# Create record
|
||||
created = db.recordCreate("test_table", record)
|
||||
results.append(f"Created: {created['id']}")
|
||||
|
||||
# Modify record
|
||||
record["data"] = f"Modified by user {user_id} - {i}"
|
||||
modified = db.recordModify("test_table", record["id"], record)
|
||||
results.append(f"Modified: {modified['id']}")
|
||||
|
||||
# Small delay to increase chance of race conditions
|
||||
time.sleep(0.001)
|
||||
|
||||
# Return connector to pool
|
||||
return_connector(db)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"User {user_id} error: {e}")
|
||||
return [f"Error: {e}"]
|
||||
|
||||
# Test with multiple concurrent users
|
||||
num_users = 20
|
||||
operations_per_user = 5
|
||||
|
||||
logger.info(f"Testing with {num_users} users, {operations_per_user} operations each")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_users) as executor:
|
||||
# Submit all user operations
|
||||
futures = [
|
||||
executor.submit(user_operation, user_id, operations_per_user)
|
||||
for user_id in range(num_users)
|
||||
]
|
||||
|
||||
# Collect results
|
||||
all_results = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
result = future.result()
|
||||
all_results.extend(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Future error: {e}")
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Verify data integrity
|
||||
db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier")
|
||||
|
||||
# Check that all records exist and are consistent
|
||||
all_records = db.getRecordset("test_table")
|
||||
expected_count = num_users * operations_per_user
|
||||
|
||||
logger.info(f"Expected records: {expected_count}")
|
||||
logger.info(f"Actual records: {len(all_records)}")
|
||||
logger.info(f"Test completed in {end_time - start_time:.2f} seconds")
|
||||
|
||||
# Check for data consistency
|
||||
record_ids = set(record["id"] for record in all_records)
|
||||
expected_ids = set(f"user_{user_id}_record_{i}" for user_id in range(num_users) for i in range(operations_per_user))
|
||||
|
||||
missing_ids = expected_ids - record_ids
|
||||
extra_ids = record_ids - expected_ids
|
||||
|
||||
if missing_ids:
|
||||
logger.error(f"Missing records: {missing_ids}")
|
||||
if extra_ids:
|
||||
logger.error(f"Extra records: {extra_ids}")
|
||||
|
||||
# Check for data corruption (records with wrong user data)
|
||||
corrupted_records = []
|
||||
for record in all_records:
|
||||
record_id = record["id"]
|
||||
user_id = int(record_id.split("_")[1])
|
||||
if f"Modified by user {user_id}" not in record.get("data", ""):
|
||||
corrupted_records.append(record_id)
|
||||
|
||||
if corrupted_records:
|
||||
logger.error(f"Corrupted records: {corrupted_records}")
|
||||
|
||||
success = len(missing_ids) == 0 and len(extra_ids) == 0 and len(corrupted_records) == 0
|
||||
|
||||
if success:
|
||||
logger.info("✅ Concurrency test PASSED - No data corruption detected")
|
||||
else:
|
||||
logger.error("❌ Concurrency test FAILED - Data corruption detected")
|
||||
|
||||
return success
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
logger.info("Cleaned up temporary directory")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up: {e}")
|
||||
|
||||
def test_metadata_consistency():
|
||||
"""Test that metadata operations are atomic."""
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_host = temp_dir
|
||||
db_database = "test_metadata"
|
||||
|
||||
try:
|
||||
logger.info("Testing metadata consistency...")
|
||||
|
||||
def concurrent_metadata_operations(user_id: int):
|
||||
"""Perform concurrent metadata operations."""
|
||||
db = get_connector(
|
||||
dbHost=db_host,
|
||||
dbDatabase=db_database,
|
||||
userId=f"user_{user_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Create multiple records rapidly
|
||||
for i in range(10):
|
||||
record = {
|
||||
"id": f"user_{user_id}_meta_{i}",
|
||||
"data": f"Metadata test {user_id}-{i}"
|
||||
}
|
||||
db.recordCreate("metadata_test", record)
|
||||
time.sleep(0.001) # Small delay
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Metadata test error for user {user_id}: {e}")
|
||||
return False
|
||||
finally:
|
||||
return_connector(db)
|
||||
|
||||
# Run concurrent metadata operations
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(concurrent_metadata_operations, i) for i in range(10)]
|
||||
results = [future.result() for future in as_completed(futures)]
|
||||
|
||||
# Verify metadata consistency
|
||||
db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier")
|
||||
records = db.getRecordset("metadata_test")
|
||||
|
||||
# Check that metadata is consistent
|
||||
metadata = db._loadTableMetadata("metadata_test")
|
||||
expected_count = len(records)
|
||||
actual_count = len(metadata["recordIds"])
|
||||
|
||||
logger.info(f"Expected record count: {expected_count}")
|
||||
logger.info(f"Metadata record count: {actual_count}")
|
||||
|
||||
success = expected_count == actual_count
|
||||
|
||||
if success:
|
||||
logger.info("✅ Metadata consistency test PASSED")
|
||||
else:
|
||||
logger.error("❌ Metadata consistency test FAILED")
|
||||
|
||||
return success
|
||||
|
||||
finally:
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting concurrency tests...")
|
||||
|
||||
# Test 1: Concurrent record operations
|
||||
test1_passed = test_concurrent_record_operations()
|
||||
|
||||
# Test 2: Metadata consistency
|
||||
test2_passed = test_metadata_consistency()
|
||||
|
||||
# Overall result
|
||||
if test1_passed and test2_passed:
|
||||
logger.info("🎉 All concurrency tests PASSED!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("💥 Some concurrency tests FAILED!")
|
||||
sys.exit(1)
|
||||
Loading…
Reference in a new issue