From 98a4323b363d4ebae1d4f698321ed75825d032b2 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Fri, 5 Sep 2025 23:35:01 +0200 Subject: [PATCH] database user session separation --- modules/connectors/connectorDbJson.py | 160 +++++++----- modules/connectors/connectorPool.py | 178 +++++++++++++ modules/interfaces/interfaceAppObjects.py | 19 +- modules/interfaces/interfaceChatObjects.py | 14 +- .../interfaces/interfaceComponentObjects.py | 14 +- notes/changelog.txt | 4 +- test_concurrency_fixes.py | 237 ++++++++++++++++++ 7 files changed, 555 insertions(+), 71 deletions(-) create mode 100644 modules/connectors/connectorPool.py create mode 100644 test_concurrency_fixes.py diff --git a/modules/connectors/connectorDbJson.py b/modules/connectors/connectorDbJson.py index 5ecb88dd..56111ad3 100644 --- a/modules/connectors/connectorDbJson.py +++ b/modules/connectors/connectorDbJson.py @@ -131,6 +131,11 @@ class DatabaseConnector: return lock + def _get_table_lock(self, table: str, timeout_seconds: int = 30): + """Get table-level lock for metadata operations""" + table_lock_key = f"table_{table}" + return self._get_file_lock(table_lock_key, timeout_seconds) + def _ensureTableDirectory(self, table: str) -> bool: """Ensures the table directory exists.""" if table == self._systemTableName: @@ -145,7 +150,9 @@ class DatabaseConnector: return False def _loadTableMetadata(self, table: str) -> Dict[str, Any]: - """Loads table metadata (list of record IDs) without loading actual records.""" + """Loads table metadata (list of record IDs) without loading actual records. + NOTE: This method is safe to call without additional locking. + """ if table in self._tableMetadataCache: return self._tableMetadataCache[table] @@ -159,7 +166,7 @@ class DatabaseConnector: try: if os.path.exists(tablePath): for fileName in os.listdir(tablePath): - if fileName.endswith('.json'): + if fileName.endswith('.json') and fileName != '_metadata.json': recordId = fileName[:-5] # Remove .json extension metadata["recordIds"].append(recordId) @@ -183,17 +190,23 @@ class DatabaseConnector: return None def _saveRecord(self, table: str, recordId: str, record: Dict[str, Any]) -> bool: - """Saves a single record to the table.""" + """Saves a single record to the table with atomic metadata operations.""" recordPath = self._getRecordPath(table, recordId) - lock = self._get_file_lock(recordPath) + record_lock = self._get_file_lock(recordPath) + table_lock = self._get_table_lock(table) try: - # Acquire lock with timeout - if not lock.acquire(timeout=30): # 30 second timeout - raise TimeoutError(f"Could not acquire lock for {recordPath} within 30 seconds") + # Acquire both locks with timeout - record lock first, then table lock + if not record_lock.acquire(timeout=30): + raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds") + + if not table_lock.acquire(timeout=30): + record_lock.release() + raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds") # Record lock acquisition time self._lock_timeouts[recordPath] = time.time() + self._lock_timeouts[f"table_{table}"] = time.time() # Ensure table directory exists if not self._ensureTableDirectory(table): @@ -239,14 +252,14 @@ class DatabaseConnector: # Atomic move from temp to final location os.replace(tempPath, recordPath) - # Update metadata + # ATOMIC: Update metadata while holding both locks metadata = self._loadTableMetadata(table) if recordId not in metadata["recordIds"]: metadata["recordIds"].append(recordId) metadata["recordIds"].sort() self._saveTableMetadata(table, metadata) - # Update cache if it exists + # Update cache if it exists (also protected by table lock) if table in self._tablesCache: # Find and update existing record or append new one found = False @@ -272,14 +285,22 @@ class DatabaseConnector: return False finally: - # ALWAYS release lock, even on error + # ALWAYS release both locks, even on error try: - if lock.locked(): - lock.release() + if table_lock.locked(): + table_lock.release() + if f"table_{table}" in self._lock_timeouts: + del self._lock_timeouts[f"table_{table}"] + except Exception as release_error: + logger.error(f"Error releasing table lock for {table}: {release_error}") + + try: + if record_lock.locked(): + record_lock.release() if recordPath in self._lock_timeouts: del self._lock_timeouts[recordPath] except Exception as release_error: - logger.error(f"Error releasing lock for {recordPath}: {release_error}") + logger.error(f"Error releasing record lock for {recordPath}: {release_error}") def _loadTable(self, table: str) -> List[Dict[str, Any]]: """Loads all records from a table folder.""" @@ -403,40 +424,21 @@ class DatabaseConnector: def _saveTableMetadata(self, table: str, metadata: Dict[str, Any]) -> bool: - """Saves table metadata to a metadata file.""" + """Saves table metadata to a metadata file. + NOTE: This method assumes the caller already holds the table lock. + """ try: # Create metadata file path metadataPath = os.path.join(self._getTablePath(table), "_metadata.json") - # Get lock for metadata file - lock = self._get_file_lock(metadataPath) + # Save metadata (caller should already hold table lock) + with open(metadataPath, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) - try: - # Acquire lock with timeout - if not lock.acquire(timeout=30): - raise TimeoutError(f"Could not acquire lock for metadata {metadataPath} within 30 seconds") - - # Record lock acquisition time - self._lock_timeouts[metadataPath] = time.time() - - # Save metadata - with open(metadataPath, 'w', encoding='utf-8') as f: - json.dump(metadata, f, indent=2, ensure_ascii=False) - - # Update cache - self._tableMetadataCache[table] = metadata - - return True - - finally: - # ALWAYS release lock - try: - if lock.locked(): - lock.release() - if metadataPath in self._lock_timeouts: - del self._lock_timeouts[metadataPath] - except Exception as release_error: - logger.error(f"Error releasing metadata lock for {metadataPath}: {release_error}") + # Update cache + self._tableMetadataCache[table] = metadata + + return True except Exception as e: logger.error(f"Error saving metadata for table {table}: {e}") @@ -582,39 +584,73 @@ class DatabaseConnector: return existingRecord def recordDelete(self, table: str, recordId: str) -> bool: - """Deletes a record from the table.""" - # Load metadata - metadata = self._loadTableMetadata(table) - - if recordId not in metadata["recordIds"]: - return False - - # Check if it's an initial record - initialId = self.getInitialId(table) - if initialId is not None and initialId == recordId: - self._removeInitialId(table) - logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table") - - # Delete the record file + """Deletes a record from the table with atomic metadata operations.""" recordPath = self._getRecordPath(table, recordId) + record_lock = self._get_file_lock(recordPath) + table_lock = self._get_table_lock(table) + try: + # Acquire both locks with timeout - record lock first, then table lock + if not record_lock.acquire(timeout=30): + raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds") + + if not table_lock.acquire(timeout=30): + record_lock.release() + raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds") + + # Record lock acquisition time + self._lock_timeouts[recordPath] = time.time() + self._lock_timeouts[f"table_{table}"] = time.time() + + # Load metadata + metadata = self._loadTableMetadata(table) + + if recordId not in metadata["recordIds"]: + return False + + # Check if it's an initial record + initialId = self.getInitialId(table) + if initialId is not None and initialId == recordId: + self._removeInitialId(table) + logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table") + + # Delete the record file if os.path.exists(recordPath): os.remove(recordPath) - # Update metadata cache + # ATOMIC: Update metadata while holding both locks metadata["recordIds"].remove(recordId) - self._tableMetadataCache[table] = metadata + self._saveTableMetadata(table, metadata) - # Update table cache if it exists + # Update table cache if it exists (also protected by table lock) if table in self._tablesCache: self._tablesCache[table] = [r for r in self._tablesCache[table] if r.get("id") != recordId] return True + else: + return False + except Exception as e: - logger.error(f"Error deleting record file {recordPath}: {e}") + logger.error(f"Error deleting record {recordId} from table {table}: {e}") return False - - return False + + finally: + # ALWAYS release both locks, even on error + try: + if table_lock.locked(): + table_lock.release() + if f"table_{table}" in self._lock_timeouts: + del self._lock_timeouts[f"table_{table}"] + except Exception as release_error: + logger.error(f"Error releasing table lock for {table}: {release_error}") + + try: + if record_lock.locked(): + record_lock.release() + if recordPath in self._lock_timeouts: + del self._lock_timeouts[recordPath] + except Exception as release_error: + logger.error(f"Error releasing record lock for {recordPath}: {release_error}") def getInitialId(self, table: str) -> Optional[str]: """Returns the initial ID for a table.""" diff --git a/modules/connectors/connectorPool.py b/modules/connectors/connectorPool.py new file mode 100644 index 00000000..3137c468 --- /dev/null +++ b/modules/connectors/connectorPool.py @@ -0,0 +1,178 @@ +import threading +import queue +import time +import logging +from typing import Optional, Dict, Any +from .connectorDbJson import DatabaseConnector + +logger = logging.getLogger(__name__) + +class DatabaseConnectorPool: + """ + A connection pool for DatabaseConnector instances to manage resources efficiently + and ensure proper isolation between users. + """ + + def __init__(self, max_connections: int = 100, max_idle_time: int = 300): + """ + Initialize the connection pool. + + Args: + max_connections: Maximum number of connections in the pool + max_idle_time: Maximum idle time in seconds before connection is considered stale + """ + self.max_connections = max_connections + self.max_idle_time = max_idle_time + self._pool = queue.Queue(maxsize=max_connections) + self._created_connections = 0 + self._lock = threading.Lock() + self._connection_times = {} # Track when connections were created + + def _create_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None, + dbPassword: str = None, userId: str = None) -> DatabaseConnector: + """Create a new DatabaseConnector instance.""" + with self._lock: + if self._created_connections >= self.max_connections: + raise RuntimeError(f"Maximum connections ({self.max_connections}) exceeded") + + self._created_connections += 1 + logger.debug(f"Creating new database connector (total: {self._created_connections})") + + connector = DatabaseConnector( + dbHost=dbHost, + dbDatabase=dbDatabase, + dbUser=dbUser, + dbPassword=dbPassword, + userId=userId + ) + + # Track creation time + connector_id = id(connector) + self._connection_times[connector_id] = time.time() + + return connector + + def get_connector(self, dbHost: str, dbDatabase: str, dbUser: str = None, + dbPassword: str = None, userId: str = None) -> DatabaseConnector: + """ + Get a database connector from the pool or create a new one. + + Args: + dbHost: Database host path + dbDatabase: Database name + dbUser: Database user (optional) + dbPassword: Database password (optional) + userId: User ID for context (optional) + + Returns: + DatabaseConnector instance + """ + try: + # Try to get an existing connector from the pool + connector = self._pool.get_nowait() + + # Check if connector is stale + connector_id = id(connector) + if connector_id in self._connection_times: + idle_time = time.time() - self._connection_times[connector_id] + if idle_time > self.max_idle_time: + logger.debug(f"Connector {connector_id} is stale (idle: {idle_time}s), creating new one") + # Remove stale connector from tracking + if connector_id in self._connection_times: + del self._connection_times[connector_id] + # Create new connector + return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId) + + # Update user context if provided + if userId is not None: + connector.updateContext(userId) + + logger.debug(f"Reusing existing connector {connector_id}") + return connector + + except queue.Empty: + # Pool is empty, create new connector + return self._create_connector(dbHost, dbDatabase, dbUser, dbPassword, userId) + + def return_connector(self, connector: DatabaseConnector) -> None: + """ + Return a connector to the pool for reuse. + + Args: + connector: DatabaseConnector instance to return + """ + try: + # Update connection time + connector_id = id(connector) + self._connection_times[connector_id] = time.time() + + # Try to return to pool + self._pool.put_nowait(connector) + logger.debug(f"Returned connector {connector_id} to pool") + + except queue.Full: + # Pool is full, discard connector + logger.debug(f"Pool full, discarding connector {id(connector)}") + with self._lock: + self._created_connections -= 1 + if id(connector) in self._connection_times: + del self._connection_times[id(connector)] + + def cleanup_stale_connections(self) -> int: + """ + Clean up stale connections from the pool. + + Returns: + Number of connections cleaned up + """ + cleaned = 0 + current_time = time.time() + + # Check all tracked connections + stale_connectors = [] + for connector_id, creation_time in list(self._connection_times.items()): + if current_time - creation_time > self.max_idle_time: + stale_connectors.append(connector_id) + + # Remove stale connections from tracking + for connector_id in stale_connectors: + if connector_id in self._connection_times: + del self._connection_times[connector_id] + cleaned += 1 + + logger.debug(f"Cleaned up {cleaned} stale connections") + return cleaned + + def get_stats(self) -> Dict[str, Any]: + """Get pool statistics.""" + with self._lock: + return { + "max_connections": self.max_connections, + "created_connections": self._created_connections, + "available_connections": self._pool.qsize(), + "tracked_connections": len(self._connection_times) + } + +# Global pool instance +_connector_pool = None +_pool_lock = threading.Lock() + +def get_connector_pool() -> DatabaseConnectorPool: + """Get the global connector pool instance.""" + global _connector_pool + if _connector_pool is None: + with _pool_lock: + if _connector_pool is None: + _connector_pool = DatabaseConnectorPool() + return _connector_pool + +def get_connector(dbHost: str, dbDatabase: str, dbUser: str = None, + dbPassword: str = None, userId: str = None) -> DatabaseConnector: + """Get a database connector from the global pool.""" + pool = get_connector_pool() + return pool.get_connector(dbHost, dbDatabase, dbUser, dbPassword, userId) + +def return_connector(connector: DatabaseConnector) -> None: + """Return a database connector to the global pool.""" + pool = get_connector_pool() + pool.return_connector(connector) diff --git a/modules/interfaces/interfaceAppObjects.py b/modules/interfaces/interfaceAppObjects.py index e9683158..25183fe2 100644 --- a/modules/interfaces/interfaceAppObjects.py +++ b/modules/interfaces/interfaceAppObjects.py @@ -13,6 +13,7 @@ from passlib.context import CryptContext import uuid from modules.connectors.connectorDbJson import DatabaseConnector +from modules.connectors.connectorPool import get_connector, return_connector from modules.shared.configuration import APP_CONFIG from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp from modules.interfaces.interfaceAppAccess import AppAccess @@ -79,8 +80,16 @@ class AppObjects: # Update database context self.db.updateContext(self.userId) + def __del__(self): + """Cleanup method to return connector to pool.""" + if hasattr(self, 'db') and self.db is not None: + try: + return_connector(self.db) + except Exception as e: + logger.error(f"Error returning connector to pool: {e}") + def _initializeDatabase(self): - """Initializes the database connection.""" + """Initializes the database connection using connection pool.""" try: # Get configuration values with defaults dbHost = APP_CONFIG.get("DB_APP_HOST", "_no_config_default_data") @@ -91,14 +100,16 @@ class AppObjects: # Ensure the database directory exists os.makedirs(dbHost, exist_ok=True) - self.db = DatabaseConnector( + # Get connector from pool with user context + self.db = get_connector( dbHost=dbHost, dbDatabase=dbDatabase, dbUser=dbUser, - dbPassword=dbPassword + dbPassword=dbPassword, + userId=self.userId ) - logger.info("Database initialized successfully") + logger.info(f"Database initialized successfully for user {self.userId}") except Exception as e: logger.error(f"Failed to initialize database: {str(e)}") raise diff --git a/modules/interfaces/interfaceChatObjects.py b/modules/interfaces/interfaceChatObjects.py index 95ebb3a5..239e76bd 100644 --- a/modules/interfaces/interfaceChatObjects.py +++ b/modules/interfaces/interfaceChatObjects.py @@ -19,6 +19,7 @@ from modules.interfaces.interfaceAppModel import User # DYNAMIC PART: Connectors to the Interface from modules.connectors.connectorDbJson import DatabaseConnector +from modules.connectors.connectorPool import get_connector, return_connector from modules.shared.timezoneUtils import get_utc_timestamp # Basic Configurations @@ -72,6 +73,14 @@ class ChatObjects: # Update database context self.db.updateContext(self.userId) + + def __del__(self): + """Cleanup method to return connector to pool.""" + if hasattr(self, 'db') and self.db is not None: + try: + return_connector(self.db) + except Exception as e: + logger.error(f"Error returning connector to pool: {e}") logger.debug(f"User context set: userId={self.userId}, mandateId={self.mandateId}") @@ -87,11 +96,12 @@ class ChatObjects: # Ensure the database directory exists os.makedirs(dbHost, exist_ok=True) - self.db = DatabaseConnector( + self.db = get_connector( dbHost=dbHost, dbDatabase=dbDatabase, dbUser=dbUser, - dbPassword=dbPassword + dbPassword=dbPassword, + userId=self.userId ) logger.info("Database initialized successfully") diff --git a/modules/interfaces/interfaceComponentObjects.py b/modules/interfaces/interfaceComponentObjects.py index 59b10ddf..36058cc7 100644 --- a/modules/interfaces/interfaceComponentObjects.py +++ b/modules/interfaces/interfaceComponentObjects.py @@ -18,6 +18,7 @@ from modules.interfaces.interfaceAppModel import User # DYNAMIC PART: Connectors to the Interface from modules.connectors.connectorDbJson import DatabaseConnector +from modules.connectors.connectorPool import get_connector, return_connector # Basic Configurations from modules.shared.configuration import APP_CONFIG @@ -87,6 +88,14 @@ class ComponentObjects: # Update database context self.db.updateContext(self.userId) + + def __del__(self): + """Cleanup method to return connector to pool.""" + if hasattr(self, 'db') and self.db is not None: + try: + return_connector(self.db) + except Exception as e: + logger.error(f"Error returning connector to pool: {e}") logger.debug(f"User context set: userId={self.userId}") @@ -102,11 +111,12 @@ class ComponentObjects: # Ensure the database directory exists os.makedirs(dbHost, exist_ok=True) - self.db = DatabaseConnector( + self.db = get_connector( dbHost=dbHost, dbDatabase=dbDatabase, dbUser=dbUser, - dbPassword=dbPassword + dbPassword=dbPassword, + userId=self.userId if hasattr(self, 'userId') else None ) logger.info("Database initialized successfully") diff --git a/notes/changelog.txt b/notes/changelog.txt index affd7a3d..60af5270 100644 --- a/notes/changelog.txt +++ b/notes/changelog.txt @@ -2,7 +2,9 @@ TODO # System -- sharepoint to fix +- database +- db initialization as separate function to create root mandate, then sysadmin with hashed passwords --> using the connector according to env configuration +- config page for: db reset - document handling centralized - ai handling centralized - neutralizer to activate AND put back placeholders to the returned data diff --git a/test_concurrency_fixes.py b/test_concurrency_fixes.py new file mode 100644 index 00000000..4613b999 --- /dev/null +++ b/test_concurrency_fixes.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +""" +Test script to verify concurrency improvements in DatabaseConnector. +This script simulates multiple users accessing the database simultaneously. +""" + +import os +import sys +import time +import threading +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +import tempfile +import shutil + +# Add the gateway directory to the path +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from modules.connectors.connectorDbJson import DatabaseConnector +from modules.connectors.connectorPool import get_connector, return_connector + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def test_concurrent_record_operations(): + """Test concurrent record creation, modification, and deletion.""" + + # Create temporary database directory + temp_dir = tempfile.mkdtemp() + db_host = temp_dir + db_database = "test_db" + + try: + logger.info("Starting concurrency test...") + + def user_operation(user_id: int, operation_count: int = 10): + """Simulate a user performing database operations.""" + try: + # Get a dedicated connector for this user + db = get_connector( + dbHost=db_host, + dbDatabase=db_database, + userId=f"user_{user_id}" + ) + + results = [] + + for i in range(operation_count): + # Create a record + record = { + "id": f"user_{user_id}_record_{i}", + "data": f"User {user_id} data {i}", + "timestamp": time.time() + } + + # Create record + created = db.recordCreate("test_table", record) + results.append(f"Created: {created['id']}") + + # Modify record + record["data"] = f"Modified by user {user_id} - {i}" + modified = db.recordModify("test_table", record["id"], record) + results.append(f"Modified: {modified['id']}") + + # Small delay to increase chance of race conditions + time.sleep(0.001) + + # Return connector to pool + return_connector(db) + + return results + + except Exception as e: + logger.error(f"User {user_id} error: {e}") + return [f"Error: {e}"] + + # Test with multiple concurrent users + num_users = 20 + operations_per_user = 5 + + logger.info(f"Testing with {num_users} users, {operations_per_user} operations each") + + start_time = time.time() + + with ThreadPoolExecutor(max_workers=num_users) as executor: + # Submit all user operations + futures = [ + executor.submit(user_operation, user_id, operations_per_user) + for user_id in range(num_users) + ] + + # Collect results + all_results = [] + for future in as_completed(futures): + try: + result = future.result() + all_results.extend(result) + except Exception as e: + logger.error(f"Future error: {e}") + + end_time = time.time() + + # Verify data integrity + db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier") + + # Check that all records exist and are consistent + all_records = db.getRecordset("test_table") + expected_count = num_users * operations_per_user + + logger.info(f"Expected records: {expected_count}") + logger.info(f"Actual records: {len(all_records)}") + logger.info(f"Test completed in {end_time - start_time:.2f} seconds") + + # Check for data consistency + record_ids = set(record["id"] for record in all_records) + expected_ids = set(f"user_{user_id}_record_{i}" for user_id in range(num_users) for i in range(operations_per_user)) + + missing_ids = expected_ids - record_ids + extra_ids = record_ids - expected_ids + + if missing_ids: + logger.error(f"Missing records: {missing_ids}") + if extra_ids: + logger.error(f"Extra records: {extra_ids}") + + # Check for data corruption (records with wrong user data) + corrupted_records = [] + for record in all_records: + record_id = record["id"] + user_id = int(record_id.split("_")[1]) + if f"Modified by user {user_id}" not in record.get("data", ""): + corrupted_records.append(record_id) + + if corrupted_records: + logger.error(f"Corrupted records: {corrupted_records}") + + success = len(missing_ids) == 0 and len(extra_ids) == 0 and len(corrupted_records) == 0 + + if success: + logger.info("✅ Concurrency test PASSED - No data corruption detected") + else: + logger.error("❌ Concurrency test FAILED - Data corruption detected") + + return success + + finally: + # Cleanup + try: + shutil.rmtree(temp_dir) + logger.info("Cleaned up temporary directory") + except Exception as e: + logger.error(f"Error cleaning up: {e}") + +def test_metadata_consistency(): + """Test that metadata operations are atomic.""" + + temp_dir = tempfile.mkdtemp() + db_host = temp_dir + db_database = "test_metadata" + + try: + logger.info("Testing metadata consistency...") + + def concurrent_metadata_operations(user_id: int): + """Perform concurrent metadata operations.""" + db = get_connector( + dbHost=db_host, + dbDatabase=db_database, + userId=f"user_{user_id}" + ) + + try: + # Create multiple records rapidly + for i in range(10): + record = { + "id": f"user_{user_id}_meta_{i}", + "data": f"Metadata test {user_id}-{i}" + } + db.recordCreate("metadata_test", record) + time.sleep(0.001) # Small delay + + return True + except Exception as e: + logger.error(f"Metadata test error for user {user_id}: {e}") + return False + finally: + return_connector(db) + + # Run concurrent metadata operations + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(concurrent_metadata_operations, i) for i in range(10)] + results = [future.result() for future in as_completed(futures)] + + # Verify metadata consistency + db = get_connector(dbHost=db_host, dbDatabase=db_database, userId="verifier") + records = db.getRecordset("metadata_test") + + # Check that metadata is consistent + metadata = db._loadTableMetadata("metadata_test") + expected_count = len(records) + actual_count = len(metadata["recordIds"]) + + logger.info(f"Expected record count: {expected_count}") + logger.info(f"Metadata record count: {actual_count}") + + success = expected_count == actual_count + + if success: + logger.info("✅ Metadata consistency test PASSED") + else: + logger.error("❌ Metadata consistency test FAILED") + + return success + + finally: + try: + shutil.rmtree(temp_dir) + except Exception as e: + logger.error(f"Error cleaning up: {e}") + +if __name__ == "__main__": + logger.info("Starting concurrency tests...") + + # Test 1: Concurrent record operations + test1_passed = test_concurrent_record_operations() + + # Test 2: Metadata consistency + test2_passed = test_metadata_consistency() + + # Overall result + if test1_passed and test2_passed: + logger.info("🎉 All concurrency tests PASSED!") + sys.exit(0) + else: + logger.error("💥 Some concurrency tests FAILED!") + sys.exit(1)