database user session separation

ValueOn AG 2025-09-05 23:35:01 +02:00
parent 1ff4248346
commit 98a4323b36
7 changed files with 555 additions and 71 deletions

View file

@@ -131,6 +131,11 @@ class DatabaseConnector:
         return lock

+    def _get_table_lock(self, table: str, timeout_seconds: int = 30):
+        """Get table-level lock for metadata operations"""
+        table_lock_key = f"table_{table}"
+        return self._get_file_lock(table_lock_key, timeout_seconds)
+
     def _ensureTableDirectory(self, table: str) -> bool:
         """Ensures the table directory exists."""
         if table == self._systemTableName:
@@ -145,7 +150,9 @@ class DatabaseConnector:
             return False

     def _loadTableMetadata(self, table: str) -> Dict[str, Any]:
-        """Loads table metadata (list of record IDs) without loading actual records."""
+        """Loads table metadata (list of record IDs) without loading actual records.
+        NOTE: This method is safe to call without additional locking.
+        """
         if table in self._tableMetadataCache:
             return self._tableMetadataCache[table]
@@ -159,7 +166,7 @@ class DatabaseConnector:
         try:
             if os.path.exists(tablePath):
                 for fileName in os.listdir(tablePath):
-                    if fileName.endswith('.json'):
+                    if fileName.endswith('.json') and fileName != '_metadata.json':
                         recordId = fileName[:-5]  # Remove .json extension
                         metadata["recordIds"].append(recordId)
@@ -183,17 +190,23 @@ class DatabaseConnector:
         return None

     def _saveRecord(self, table: str, recordId: str, record: Dict[str, Any]) -> bool:
-        """Saves a single record to the table."""
+        """Saves a single record to the table with atomic metadata operations."""
         recordPath = self._getRecordPath(table, recordId)
-        lock = self._get_file_lock(recordPath)
+        record_lock = self._get_file_lock(recordPath)
+        table_lock = self._get_table_lock(table)

         try:
-            # Acquire lock with timeout
-            if not lock.acquire(timeout=30):  # 30 second timeout
-                raise TimeoutError(f"Could not acquire lock for {recordPath} within 30 seconds")
+            # Acquire both locks with timeout - record lock first, then table lock
+            if not record_lock.acquire(timeout=30):
+                raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds")
+            if not table_lock.acquire(timeout=30):
+                record_lock.release()
+                raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds")

             # Record lock acquisition time
             self._lock_timeouts[recordPath] = time.time()
+            self._lock_timeouts[f"table_{table}"] = time.time()

             # Ensure table directory exists
             if not self._ensureTableDirectory(table):
@@ -239,14 +252,14 @@ class DatabaseConnector:
             # Atomic move from temp to final location
             os.replace(tempPath, recordPath)

-            # Update metadata
+            # ATOMIC: Update metadata while holding both locks
             metadata = self._loadTableMetadata(table)
             if recordId not in metadata["recordIds"]:
                 metadata["recordIds"].append(recordId)
                 metadata["recordIds"].sort()
                 self._saveTableMetadata(table, metadata)

-            # Update cache if it exists
+            # Update cache if it exists (also protected by table lock)
             if table in self._tablesCache:
                 # Find and update existing record or append new one
                 found = False
@@ -272,14 +285,22 @@ class DatabaseConnector:
             return False

         finally:
-            # ALWAYS release lock, even on error
+            # ALWAYS release both locks, even on error
             try:
-                if lock.locked():
-                    lock.release()
+                if table_lock.locked():
+                    table_lock.release()
+                if f"table_{table}" in self._lock_timeouts:
+                    del self._lock_timeouts[f"table_{table}"]
+            except Exception as release_error:
+                logger.error(f"Error releasing table lock for {table}: {release_error}")
+
+            try:
+                if record_lock.locked():
+                    record_lock.release()
                 if recordPath in self._lock_timeouts:
                     del self._lock_timeouts[recordPath]
             except Exception as release_error:
-                logger.error(f"Error releasing lock for {recordPath}: {release_error}")
+                logger.error(f"Error releasing record lock for {recordPath}: {release_error}")

     def _loadTable(self, table: str) -> List[Dict[str, Any]]:
         """Loads all records from a table folder."""
@@ -403,40 +424,21 @@ class DatabaseConnector:
     def _saveTableMetadata(self, table: str, metadata: Dict[str, Any]) -> bool:
-        """Saves table metadata to a metadata file."""
+        """Saves table metadata to a metadata file.
+        NOTE: This method assumes the caller already holds the table lock.
+        """
         try:
             # Create metadata file path
             metadataPath = os.path.join(self._getTablePath(table), "_metadata.json")

-            # Get lock for metadata file
-            lock = self._get_file_lock(metadataPath)
-
-            try:
-                # Acquire lock with timeout
-                if not lock.acquire(timeout=30):
-                    raise TimeoutError(f"Could not acquire lock for metadata {metadataPath} within 30 seconds")
-
-                # Record lock acquisition time
-                self._lock_timeouts[metadataPath] = time.time()
-
-                # Save metadata
-                with open(metadataPath, 'w', encoding='utf-8') as f:
-                    json.dump(metadata, f, indent=2, ensure_ascii=False)
-
-                # Update cache
-                self._tableMetadataCache[table] = metadata
-
-                return True
-
-            finally:
-                # ALWAYS release lock
-                try:
-                    if lock.locked():
-                        lock.release()
-                    if metadataPath in self._lock_timeouts:
-                        del self._lock_timeouts[metadataPath]
-                except Exception as release_error:
-                    logger.error(f"Error releasing metadata lock for {metadataPath}: {release_error}")
+            # Save metadata (caller should already hold table lock)
+            with open(metadataPath, 'w', encoding='utf-8') as f:
+                json.dump(metadata, f, indent=2, ensure_ascii=False)
+
+            # Update cache
+            self._tableMetadataCache[table] = metadata
+
+            return True

         except Exception as e:
             logger.error(f"Error saving metadata for table {table}: {e}")
@@ -582,39 +584,73 @@ class DatabaseConnector:
         return existingRecord

     def recordDelete(self, table: str, recordId: str) -> bool:
-        """Deletes a record from the table."""
-        # Load metadata
-        metadata = self._loadTableMetadata(table)
-        if recordId not in metadata["recordIds"]:
-            return False
-
-        # Check if it's an initial record
-        initialId = self.getInitialId(table)
-        if initialId is not None and initialId == recordId:
-            self._removeInitialId(table)
-            logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table")
-
-        # Delete the record file
+        """Deletes a record from the table with atomic metadata operations."""
         recordPath = self._getRecordPath(table, recordId)
+        record_lock = self._get_file_lock(recordPath)
+        table_lock = self._get_table_lock(table)

         try:
+            # Acquire both locks with timeout - record lock first, then table lock
+            if not record_lock.acquire(timeout=30):
+                raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds")
+            if not table_lock.acquire(timeout=30):
+                record_lock.release()
+                raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds")
+
+            # Record lock acquisition time
+            self._lock_timeouts[recordPath] = time.time()
+            self._lock_timeouts[f"table_{table}"] = time.time()
+
+            # Load metadata
+            metadata = self._loadTableMetadata(table)
+            if recordId not in metadata["recordIds"]:
+                return False
+
+            # Check if it's an initial record
+            initialId = self.getInitialId(table)
+            if initialId is not None and initialId == recordId:
+                self._removeInitialId(table)
+                logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table")
+
+            # Delete the record file
             if os.path.exists(recordPath):
                 os.remove(recordPath)

-                # Update metadata cache
+                # ATOMIC: Update metadata while holding both locks
                 metadata["recordIds"].remove(recordId)
-                self._tableMetadataCache[table] = metadata
+                self._saveTableMetadata(table, metadata)

-                # Update table cache if it exists
+                # Update table cache if it exists (also protected by table lock)
                 if table in self._tablesCache:
                     self._tablesCache[table] = [r for r in self._tablesCache[table] if r.get("id") != recordId]

                 return True
+            else:
+                return False
         except Exception as e:
-            logger.error(f"Error deleting record file {recordPath}: {e}")
+            logger.error(f"Error deleting record {recordId} from table {table}: {e}")
             return False
-
-        return False
+        finally:
+            # ALWAYS release both locks, even on error
+            try:
+                if table_lock.locked():
+                    table_lock.release()
+                if f"table_{table}" in self._lock_timeouts:
+                    del self._lock_timeouts[f"table_{table}"]
+            except Exception as release_error:
+                logger.error(f"Error releasing table lock for {table}: {release_error}")
+            try:
+                if record_lock.locked():
+                    record_lock.release()
+                if recordPath in self._lock_timeouts:
+                    del self._lock_timeouts[recordPath]
+            except Exception as release_error:
+                logger.error(f"Error releasing record lock for {recordPath}: {release_error}")

     def getInitialId(self, table: str) -> Optional[str]:
         """Returns the initial ID for a table."""

View file

@@ -0,0 +1,178 @@
import threading
import queue
import time
import logging
from typing import Optional, Dict, Any

from .connectorDbJson import DatabaseConnector

logger = logging.getLogger(__name__)


class DatabaseConnectorPool:
    """
    A connection pool for DatabaseConnector instances to manage resources efficiently
    and ensure proper isolation between users.
    """

    def __init__(self, max_connections: int = 100, max_idle_time: int = 300):
        """
        Initialize the connection pool.

        Args:
            max_connections: Maximum number of connections in the pool
            max_idle_time: Maximum idle time in seconds before connection is considered stale
        """
        self.max_connections = max_connections
        self.max_idle_time = max_idle_time
        self._pool = queue.Queue(maxsize=max_connections)
        self._created_connections = 0
        self._lock = threading.Lock()
        self._connection_times = {}  # Track when connections were created

    def _create_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None,
                          dbPassword: str = None, userId: str = None) -> DatabaseConnector:
        """Create a new DatabaseConnector instance."""
        with self._lock:
            if self._created_connections >= self.max_connections:
                raise RuntimeError(f"Maximum connections ({self.max_connections}) exceeded")
            self._created_connections += 1
            logger.debug(f"Creating new database connector (total: {self._created_connections})")

        connector = DatabaseConnector(
            dbHost=dbHost,
            dbDatabase=dbDatabase,
            dbUser=dbUser,
            dbPassword=dbPassword,
            userId=userId
        )

        # Track creation time
        connector_id = id(connector)
        self._connection_times[connector_id] = time.time()

        return connector

    def get_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None,
                      dbPassword: str = None, userId: str = None) -> DatabaseConnector:
        """
        Get a database connector from the pool or create a new one.

        Args:
            dbHost: Database host path
            dbDatabase: Database name
            dbUser: Database user (optional)
            dbPassword: Database password (optional)
            userId: User ID for context (optional)

        Returns:
            DatabaseConnector instance
        """
        try:
            # Try to get an existing connector from the pool
            connector = self._pool.get_nowait()

            # Check if connector is stale
            connector_id = id(connector)
            if connector_id in self._connection_times:
                idle_time = time.time() - self._connection_times[connector_id]
                if idle_time > self.max_idle_time:
                    logger.debug(f"Connector {connector_id} is stale (idle: {idle_time}s), creating new one")
                    # Remove stale connector from tracking
                    if connector_id in self._connection_times:
                        del self._connection_times[connector_id]
                    # Create new connector
                    return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)

            # Update user context if provided
            if userId is not None:
                connector.updateContext(userId)

            logger.debug(f"Reusing existing connector {connector_id}")
            return connector

        except queue.Empty:
            # Pool is empty, create new connector
            return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)

    def return_connector(self, connector: DatabaseConnector) -> None:
        """
        Return a connector to the pool for reuse.

        Args:
            connector: DatabaseConnector instance to return
        """
        try:
            # Update connection time
            connector_id = id(connector)
            self._connection_times[connector_id] = time.time()

            # Try to return to pool
            self._pool.put_nowait(connector)
            logger.debug(f"Returned connector {connector_id} to pool")

        except queue.Full:
            # Pool is full, discard connector
            logger.debug(f"Pool full, discarding connector {id(connector)}")
            with self._lock:
                self._created_connections -= 1
                if id(connector) in self._connection_times:
                    del self._connection_times[id(connector)]

    def cleanup_stale_connections(self) -> int:
        """
        Clean up stale connections from the pool.

        Returns:
            Number of connections cleaned up
        """
        cleaned = 0
        current_time = time.time()

        # Check all tracked connections
        stale_connectors = []
        for connector_id, creation_time in list(self._connection_times.items()):
            if current_time - creation_time > self.max_idle_time:
                stale_connectors.append(connector_id)

        # Remove stale connections from tracking
        for connector_id in stale_connectors:
            if connector_id in self._connection_times:
                del self._connection_times[connector_id]
                cleaned += 1

        logger.debug(f"Cleaned up {cleaned} stale connections")
        return cleaned

    def get_stats(self) -> Dict[str, Any]:
        """Get pool statistics."""
        with self._lock:
            return {
                "max_connections": self.max_connections,
                "created_connections": self._created_connections,
                "available_connections": self._pool.qsize(),
                "tracked_connections": len(self._connection_times)
            }


# Global pool instance
_connector_pool = None
_pool_lock = threading.Lock()


def get_connector_pool() -> DatabaseConnectorPool:
    """Get the global connector pool instance."""
    global _connector_pool
    if _connector_pool is None:
        with _pool_lock:
            if _connector_pool is None:
                _connector_pool = DatabaseConnectorPool()
    return _connector_pool


def get_connector(dbHost: str, dbDatabase: str, dbUser: str = None,
                  dbPassword: str = None, userId: str = None) -> DatabaseConnector:
    """Get a database connector from the global pool."""
    pool = get_connector_pool()
    return pool.get_connector(dbHost, dbDatabase, dbUser, dbPassword, userId)


def return_connector(connector: DatabaseConnector) -> None:
    """Return a database connector to the global pool."""
    pool = get_connector_pool()
    pool.return_connector(connector)
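For context, a short usage sketch of the pool's module-level helpers. The host path, database name, and table name below are placeholders (real values come from APP_CONFIG), and the try/finally pairing mirrors how the test script at the end of this commit returns connectors.

```python
from modules.connectors.connectorPool import get_connector, return_connector

def handle_request(user_id: str):
    # Illustrative values; the app classes read dbHost/dbDatabase from APP_CONFIG.
    db = get_connector(dbHost="./data", dbDatabase="app_db", userId=user_id)
    try:
        # Each caller gets a connector whose context is bound to its own user.
        return db.getRecordset("some_table")  # hypothetical table name
    finally:
        # Hand the connector back so other requests can reuse it.
        return_connector(db)
```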

View file

@@ -13,6 +13,7 @@ from passlib.context import CryptContext
 import uuid

 from modules.connectors.connectorDbJson import DatabaseConnector
+from modules.connectors.connectorPool import get_connector, return_connector
 from modules.shared.configuration import APP_CONFIG
 from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp
 from modules.interfaces.interfaceAppAccess import AppAccess
@@ -79,8 +80,16 @@ class AppObjects:
         # Update database context
         self.db.updateContext(self.userId)

+    def __del__(self):
+        """Cleanup method to return connector to pool."""
+        if hasattr(self, 'db') and self.db is not None:
+            try:
+                return_connector(self.db)
+            except Exception as e:
+                logger.error(f"Error returning connector to pool: {e}")
+
     def _initializeDatabase(self):
-        """Initializes the database connection."""
+        """Initializes the database connection using connection pool."""
         try:
             # Get configuration values with defaults
             dbHost = APP_CONFIG.get("DB_APP_HOST", "_no_config_default_data")
@@ -91,14 +100,16 @@ class AppObjects:
             # Ensure the database directory exists
             os.makedirs(dbHost, exist_ok=True)

-            self.db = DatabaseConnector(
+            # Get connector from pool with user context
+            self.db = get_connector(
                 dbHost=dbHost,
                 dbDatabase=dbDatabase,
                 dbUser=dbUser,
-                dbPassword=dbPassword
+                dbPassword=dbPassword,
+                userId=self.userId
             )
-            logger.info("Database initialized successfully")
+            logger.info(f"Database initialized successfully for user {self.userId}")
         except Exception as e:
             logger.error(f"Failed to initialize database: {str(e)}")
             raise

View file

@@ -19,6 +19,7 @@ from modules.interfaces.interfaceAppModel import User

 # DYNAMIC PART: Connectors to the Interface
 from modules.connectors.connectorDbJson import DatabaseConnector
+from modules.connectors.connectorPool import get_connector, return_connector
 from modules.shared.timezoneUtils import get_utc_timestamp

 # Basic Configurations
@@ -72,6 +73,14 @@ class ChatObjects:
         # Update database context
         self.db.updateContext(self.userId)

+    def __del__(self):
+        """Cleanup method to return connector to pool."""
+        if hasattr(self, 'db') and self.db is not None:
+            try:
+                return_connector(self.db)
+            except Exception as e:
+                logger.error(f"Error returning connector to pool: {e}")
+
         logger.debug(f"User context set: userId={self.userId}, mandateId={self.mandateId}")
@@ -87,11 +96,12 @@ class ChatObjects:
             # Ensure the database directory exists
             os.makedirs(dbHost, exist_ok=True)

-            self.db = DatabaseConnector(
+            self.db = get_connector(
                 dbHost=dbHost,
                 dbDatabase=dbDatabase,
                 dbUser=dbUser,
-                dbPassword=dbPassword
+                dbPassword=dbPassword,
+                userId=self.userId
             )

             logger.info("Database initialized successfully")

View file

@@ -18,6 +18,7 @@ from modules.interfaces.interfaceAppModel import User

 # DYNAMIC PART: Connectors to the Interface
 from modules.connectors.connectorDbJson import DatabaseConnector
+from modules.connectors.connectorPool import get_connector, return_connector

 # Basic Configurations
 from modules.shared.configuration import APP_CONFIG
@@ -87,6 +88,14 @@ class ComponentObjects:
         # Update database context
         self.db.updateContext(self.userId)

+    def __del__(self):
+        """Cleanup method to return connector to pool."""
+        if hasattr(self, 'db') and self.db is not None:
+            try:
+                return_connector(self.db)
+            except Exception as e:
+                logger.error(f"Error returning connector to pool: {e}")
+
         logger.debug(f"User context set: userId={self.userId}")
@@ -102,11 +111,12 @@ class ComponentObjects:
             # Ensure the database directory exists
             os.makedirs(dbHost, exist_ok=True)

-            self.db = DatabaseConnector(
+            self.db = get_connector(
                 dbHost=dbHost,
                 dbDatabase=dbDatabase,
                 dbUser=dbUser,
-                dbPassword=dbPassword
+                dbPassword=dbPassword,
+                userId=self.userId if hasattr(self, 'userId') else None
             )

             logger.info("Database initialized successfully")

View file

@@ -2,7 +2,9 @@
 TODO

 # System
-- sharepoint to fix
+- database
+- db initialization as separate function to create root mandate, then sysadmin with hashed passwords --> using the connector according to env configuration
+- config page for: db reset
 - document handling centralized
 - ai handling centralized
 - neutralizer to activate AND put back placeholders to the returned data
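The new db-initialization TODO could eventually take a shape like the rough sketch below. Table names, field names, and the `initialize_database` signature are invented here for illustration; only `CryptContext` (already imported by the app objects) and the pool helper come from the codebase.

```python
from passlib.context import CryptContext
from modules.connectors.connectorPool import get_connector

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def initialize_database(dbHost: str, dbDatabase: str, sysadmin_password: str):
    """Hypothetical shape of the TODO item: seed a root mandate and a sysadmin user."""
    db = get_connector(dbHost=dbHost, dbDatabase=dbDatabase, userId="system")
    # Placeholder table and field names, not the project's real schema.
    db.recordCreate("mandates", {"id": "root", "name": "Root Mandate"})
    db.recordCreate("users", {
        "id": "sysadmin",
        "mandateId": "root",
        "passwordHash": pwd_context.hash(sysadmin_password),
    })
```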

test_concurrency_fixes.py (new file, +237 lines)
View file

@@ -0,0 +1,237 @@
#!/usr/bin/env python3
"""
Test script to verify concurrency improvements in DatabaseConnector.
This script simulates multiple users accessing the database simultaneously.
"""

import os
import sys
import time
import threading
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import tempfile
import shutil

# Add the gateway directory to the path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from modules.connectors.connectorDbJson import DatabaseConnector
from modules.connectors.connectorPool import get_connector, return_connector

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def test_concurrent_record_operations():
    """Test concurrent record creation, modification, and deletion."""
    # Create temporary database directory
    temp_dir = tempfile.mkdtemp()
    db_host = temp_dir
    db_database = "test_db"

    try:
        logger.info("Starting concurrency test...")

        def user_operation(user_id: int, operation_count: int = 10):
            """Simulate a user performing database operations."""
            try:
                # Get a dedicated connector for this user
                db = get_connector(
                    dbHost=db_host,
                    dbDatabase=db_database,
                    userId=f"user_{user_id}"
                )

                results = []
                for i in range(operation_count):
                    # Create a record
                    record = {
                        "id": f"user_{user_id}_record_{i}",
                        "data": f"User {user_id} data {i}",
                        "timestamp": time.time()
                    }

                    # Create record
                    created = db.recordCreate("test_table", record)
                    results.append(f"Created: {created['id']}")

                    # Modify record
                    record["data"] = f"Modified by user {user_id} - {i}"
                    modified = db.recordModify("test_table", record["id"], record)
                    results.append(f"Modified: {modified['id']}")

                    # Small delay to increase chance of race conditions
                    time.sleep(0.001)

                # Return connector to pool
                return_connector(db)
                return results

            except Exception as e:
                logger.error(f"User {user_id} error: {e}")
                return [f"Error: {e}"]

        # Test with multiple concurrent users
        num_users = 20
        operations_per_user = 5

        logger.info(f"Testing with {num_users} users, {operations_per_user} operations each")

        start_time = time.time()

        with ThreadPoolExecutor(max_workers=num_users) as executor:
            # Submit all user operations
            futures = [
                executor.submit(user_operation, user_id, operations_per_user)
                for user_id in range(num_users)
            ]

            # Collect results
            all_results = []
            for future in as_completed(futures):
                try:
                    result = future.result()
                    all_results.extend(result)
                except Exception as e:
                    logger.error(f"Future error: {e}")

        end_time = time.time()

        # Verify data integrity
        db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier")

        # Check that all records exist and are consistent
        all_records = db.getRecordset("test_table")
        expected_count = num_users * operations_per_user

        logger.info(f"Expected records: {expected_count}")
        logger.info(f"Actual records: {len(all_records)}")
        logger.info(f"Test completed in {end_time - start_time:.2f} seconds")

        # Check for data consistency
        record_ids = set(record["id"] for record in all_records)
        expected_ids = set(f"user_{user_id}_record_{i}" for user_id in range(num_users) for i in range(operations_per_user))

        missing_ids = expected_ids - record_ids
        extra_ids = record_ids - expected_ids

        if missing_ids:
            logger.error(f"Missing records: {missing_ids}")
        if extra_ids:
            logger.error(f"Extra records: {extra_ids}")

        # Check for data corruption (records with wrong user data)
        corrupted_records = []
        for record in all_records:
            record_id = record["id"]
            user_id = int(record_id.split("_")[1])
            if f"Modified by user {user_id}" not in record.get("data", ""):
                corrupted_records.append(record_id)

        if corrupted_records:
            logger.error(f"Corrupted records: {corrupted_records}")

        success = len(missing_ids) == 0 and len(extra_ids) == 0 and len(corrupted_records) == 0

        if success:
            logger.info("✅ Concurrency test PASSED - No data corruption detected")
        else:
            logger.error("❌ Concurrency test FAILED - Data corruption detected")

        return success

    finally:
        # Cleanup
        try:
            shutil.rmtree(temp_dir)
            logger.info("Cleaned up temporary directory")
        except Exception as e:
            logger.error(f"Error cleaning up: {e}")


def test_metadata_consistency():
    """Test that metadata operations are atomic."""
    temp_dir = tempfile.mkdtemp()
    db_host = temp_dir
    db_database = "test_metadata"

    try:
        logger.info("Testing metadata consistency...")

        def concurrent_metadata_operations(user_id: int):
            """Perform concurrent metadata operations."""
            db = get_connector(
                dbHost=db_host,
                dbDatabase=db_database,
                userId=f"user_{user_id}"
            )
            try:
                # Create multiple records rapidly
                for i in range(10):
                    record = {
                        "id": f"user_{user_id}_meta_{i}",
                        "data": f"Metadata test {user_id}-{i}"
                    }
                    db.recordCreate("metadata_test", record)
                    time.sleep(0.001)  # Small delay

                return True
            except Exception as e:
                logger.error(f"Metadata test error for user {user_id}: {e}")
                return False
            finally:
                return_connector(db)

        # Run concurrent metadata operations
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(concurrent_metadata_operations, i) for i in range(10)]
            results = [future.result() for future in as_completed(futures)]

        # Verify metadata consistency
        db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier")
        records = db.getRecordset("metadata_test")

        # Check that metadata is consistent
        metadata = db._loadTableMetadata("metadata_test")

        expected_count = len(records)
        actual_count = len(metadata["recordIds"])

        logger.info(f"Expected record count: {expected_count}")
        logger.info(f"Metadata record count: {actual_count}")

        success = expected_count == actual_count

        if success:
            logger.info("✅ Metadata consistency test PASSED")
        else:
            logger.error("❌ Metadata consistency test FAILED")

        return success

    finally:
        try:
            shutil.rmtree(temp_dir)
        except Exception as e:
            logger.error(f"Error cleaning up: {e}")


if __name__ == "__main__":
    logger.info("Starting concurrency tests...")

    # Test 1: Concurrent record operations
    test1_passed = test_concurrent_record_operations()

    # Test 2: Metadata consistency
    test2_passed = test_metadata_consistency()

    # Overall result
    if test1_passed and test2_passed:
        logger.info("🎉 All concurrency tests PASSED!")
        sys.exit(0)
    else:
        logger.error("💥 Some concurrency tests FAILED!")
        sys.exit(1)