From 2bb65c230369b3b54e8f8c06e32d8296d1c8e100 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Sun, 17 May 2026 20:38:37 +0200 Subject: [PATCH 1/6] db connection pooling and rag limit transparency --- app.py | 11 +- modules/connectors/connectorDbPostgre.py | 1180 ++++++++++------- .../realEstate/interfaceFeatureRealEstate.py | 8 +- modules/interfaces/interfaceDbBilling.py | 18 +- modules/interfaces/interfaceDbKnowledge.py | 16 + modules/interfaces/interfaceDbManagement.py | 32 +- modules/interfaces/interfaceRbac.py | 12 +- modules/routes/routeHelpers.py | 2 +- modules/routes/routeRagInventory.py | 65 +- modules/routes/routeWorkflowDashboard.py | 4 +- .../coreTools/_featureSubAgentTools.py | 15 +- .../subConnectorSyncClickup.py | 43 +- .../subConnectorSyncGdrive.py | 43 +- .../subConnectorSyncKdrive.py | 46 +- .../subConnectorSyncSharepoint.py | 56 +- modules/shared/configuration.py | 63 +- modules/shared/dbMultiTenantOptimizations.py | 101 +- modules/shared/gdprDeletion.py | 178 ++- scripts/stage0_filefolder_schema_check.py | 2 +- .../test_connectorDbPostgre_failLoud.py | 82 +- .../test_connectorDbPostgre_pool.py | 304 +++++ tests/unit/interfaces/test_folderRbac.py | 10 + tests/unit/routes/test_folder_crud.py | 10 + 23 files changed, 1519 insertions(+), 782 deletions(-) create mode 100644 tests/unit/connectors/test_connectorDbPostgre_pool.py diff --git a/app.py b/app.py index 7a4ed4d4..d94c7dd5 100644 --- a/app.py +++ b/app.py @@ -438,7 +438,16 @@ async def lifespan(app: FastAPI): logger.error(f"Feature '{featureName}' failed to stop: {e}") except Exception as e: logger.warning(f"Could not shutdown feature containers: {e}") - + + # --- Close all PostgreSQL connection pools --- + # Must run LAST: feature `onStop` hooks may still issue DB calls during + # shutdown. Once we tear down the pools, no more borrows are possible. + try: + from modules.connectors.connectorDbPostgre import closeAllPools + closeAllPools() + except Exception as e: + logger.warning(f"Closing DB connection pools failed: {e}") + logger.info("Application has been shut down") diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index a6893396..f1a34f70 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -2,9 +2,12 @@ # All rights reserved. import contextvars import re +import time import psycopg2 import psycopg2.extras +import psycopg2.pool import logging +from contextlib import contextmanager from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type import uuid from pydantic import BaseModel, Field @@ -44,24 +47,6 @@ class DatabaseQueryError(RuntimeError): self.original = original -def _rollbackQuietly(connection) -> None: - """Restore the connection state after a failed query. - - Postgres puts the connection in an error state after any failed - statement; subsequent queries on the same connection raise - ``InFailedSqlTransaction`` until we rollback. We swallow rollback - errors because the original query error is what the caller should - see — a secondary rollback failure typically means the connection - is gone and will be reopened on the next ``_ensure_connection``. - """ - if connection is None: - return - try: - connection.rollback() - except Exception: - pass - - class SystemTable(PowerOnModel): """Data model for system table entries""" @@ -203,9 +188,174 @@ def _quotePgIdent(name: str) -> str: return '"' + str(name).replace('"', '""') + '"' -# Cache connectors by (host, database, port) to avoid duplicate inits for same database. -# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via -# contextvars to avoid races when concurrent requests share the same connector. +# --------------------------------------------------------------------------- +# Connection pool registry +# --------------------------------------------------------------------------- +# psycopg2 connections are NOT thread-safe; sharing one connection across the +# FastAPI threadpool (sync `def` routes) or across multiple async tasks that +# happen to be active simultaneously results in either +# `OperationalError: another command is already in progress` or — far worse — +# an unbounded hang in `recv()` on the connection socket, because two cursors +# wait for one server response. +# +# `_PoolRegistry` keeps one `ThreadedConnectionPool` per database identity +# (`host`, `db`, `port`). Every DB call borrows a connection from the pool, +# runs its query, and returns the connection. The pool itself guarantees that +# no two callers share the same psycopg2 connection at the same time. +# +# `statement_timeout=30000` (30 s) is a safety net: a runaway query is aborted +# instead of hanging forever and poisoning the connection. `connect_timeout=10` +# prevents indefinite blocking when the Postgres host is unreachable. + +_DEFAULT_POOL_MIN = 2 +_DEFAULT_POOL_MAX = 20 +_STATEMENT_TIMEOUT_MS = 30000 +_CONNECT_TIMEOUT_S = 10 +# psycopg2.pool.ThreadedConnectionPool.getconn() does NOT block when the pool +# is exhausted — it raises `psycopg2.pool.PoolError` immediately. That makes +# bursty workloads (50 concurrent route calls against a max=20 pool) fail +# spuriously. `borrowConn()` therefore retries with a small backoff up to +# `_BORROW_WAIT_TIMEOUT_S` seconds before giving up. +_BORROW_WAIT_TIMEOUT_S = 30.0 +_BORROW_WAIT_BACKOFF_S = 0.05 + + +def _resolvePoolMax() -> int: + """Pool size is configurable via `DB_POOL_MAX_CONN` (default 20).""" + try: + return max(2, int(APP_CONFIG.get("DB_POOL_MAX_CONN") or _DEFAULT_POOL_MAX)) + except (ValueError, TypeError): + return _DEFAULT_POOL_MAX + + +class _PoolRegistry: + """Process-wide registry of `ThreadedConnectionPool` instances. + + Keyed by `(host, database, port)` so that two `DatabaseConnector` instances + pointing at the same physical database share one pool. Lazy-initialised and + thread-safe. + """ + + _pools: Dict[tuple, psycopg2.pool.ThreadedConnectionPool] = {} + _lock = threading.Lock() + + @classmethod + def getPool( + cls, + *, + dbHost: str, + dbDatabase: str, + dbUser: str, + dbPassword: str, + dbPort: int, + ) -> psycopg2.pool.ThreadedConnectionPool: + port = int(dbPort) if dbPort is not None else 5432 + key = (dbHost, dbDatabase, port) + # Fast path: pool exists. + pool = cls._pools.get(key) + if pool is not None: + return pool + # Slow path: create exactly one pool per key, even under contention. + with cls._lock: + pool = cls._pools.get(key) + if pool is not None: + return pool + poolMax = _resolvePoolMax() + options = f"-c statement_timeout={_STATEMENT_TIMEOUT_MS}" + try: + pool = psycopg2.pool.ThreadedConnectionPool( + _DEFAULT_POOL_MIN, + poolMax, + host=dbHost, + port=port, + database=dbDatabase, + user=dbUser, + password=dbPassword, + client_encoding="utf8", + cursor_factory=psycopg2.extras.RealDictCursor, + connect_timeout=_CONNECT_TIMEOUT_S, + options=options, + ) + except Exception as e: + logger.error( + "Failed to create connection pool for db=%s host=%s: %s", + dbDatabase, dbHost, e, + ) + raise + cls._pools[key] = pool + logger.debug( + "Created connection pool for db=%s host=%s port=%s (min=%d max=%d, stmt_timeout=%dms)", + dbDatabase, dbHost, port, _DEFAULT_POOL_MIN, poolMax, _STATEMENT_TIMEOUT_MS, + ) + return pool + + @classmethod + def closeAll(cls) -> None: + """Close every pool. Call once during FastAPI shutdown.""" + with cls._lock: + for key, pool in list(cls._pools.items()): + try: + pool.closeall() + except Exception as e: + logger.warning("Error closing pool %s: %s", key, e) + cls._pools.clear() + logger.info("All database connection pools closed") + + +def closeAllPools() -> None: + """Public entry point for FastAPI lifespan shutdown hook.""" + _PoolRegistry.closeAll() + + +class _ConnectionShim: + """Backward-compatibility stand-in for the old `DatabaseConnector.connection`. + + Old callers used patterns like:: + + db.connection.commit() + if db.connection and not db.connection.closed: ... + + These are now no-ops because the pool owns the connection lifecycle. + Direct cursor access through `self.connection.cursor()` is intentionally + blocked with a clear error so silent breakage is impossible — every such + call site must migrate to `db.borrowCursor()`. + """ + + closed = False + + def __bool__(self) -> bool: + return True + + def commit(self) -> None: + return + + def rollback(self) -> None: + return + + def close(self) -> None: + return + + def cursor(self, *args, **kwargs): + raise RuntimeError( + "DatabaseConnector.connection.cursor() is no longer supported. " + "Use `db.borrowCursor()` (or `db.borrowConn()` for multi-statement " + "transactions) so the connection is borrowed from and returned to " + "the pool correctly." + ) + + +_CONNECTION_SHIM = _ConnectionShim() + + +# --------------------------------------------------------------------------- +# Connector cache (lightweight wrappers — actual connections live in the pool) +# --------------------------------------------------------------------------- +# Multiple call sites (`routeI18n`, `aiAuditLogger`, `mainBackgroundJobService`, +# interfaces) ask for a connector via `getCachedConnector(...)` and expect to +# get back the same object on subsequent calls. Now that real connection +# multiplexing happens at the pool layer, the cache returns lightweight +# `DatabaseConnector` wrappers — they hold no connection themselves, only the +# DSN params and a reference to the shared pool. _MAX_CACHED_CONNECTORS = 32 _connector_cache: Dict[tuple, "DatabaseConnector"] = {} _connector_cache_order: List[tuple] = [] # FIFO order for eviction @@ -223,22 +373,36 @@ def getCachedConnector( dbPort: int = None, userId: str = None, ) -> "DatabaseConnector": - """Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits. - Uses contextvars for userId so concurrent requests sharing the same connector get correct sysCreatedBy/sysModifiedBy. + """Return a cached `DatabaseConnector` wrapper for `(host, database, port)`. + + Two layers of caching, both intentional: + + 1. **Pool layer** (`_PoolRegistry`) — owns the actual psycopg2 connections. + Per `(host, db, port)` exactly one `ThreadedConnectionPool`, shared + across every wrapper that points at the same physical database. This + is what actually saves Postgres connection slots. + + 2. **Wrapper layer** (this function) — caches the `DatabaseConnector` + Python object. Each wrapper triggers an `initDbSystem()` on first + instantiation (CREATE DATABASE if missing, CREATE TABLE _system, + pool warm-up). Caching the wrapper avoids paying that bootstrap cost + on every request and stops the log from filling with "PostgreSQL + database system initialized" lines. + + `userId` is request-scoped via the `_current_user_id` contextvar so two + concurrent requests sharing the same cached wrapper still produce + correct `sysCreatedBy` / `sysModifiedBy` audit fields. """ port = int(dbPort) if dbPort is not None else 5432 key = (dbHost, dbDatabase, port) with _connector_cache_lock: if key not in _connector_cache: - # Evict oldest if at capacity + # FIFO eviction. Connectors are now lightweight (no per-instance + # connection), so eviction is purely a memory bookkeeping concern; + # the underlying pool stays alive in `_PoolRegistry`. while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order: oldest_key = _connector_cache_order.pop(0) - if oldest_key in _connector_cache: - try: - _connector_cache[oldest_key].close(forceClose=True) - except Exception as e: - logger.warning(f"Error closing evicted connector: {e}") - del _connector_cache[oldest_key] + _connector_cache.pop(oldest_key, None) _connector_cache[key] = DatabaseConnector( dbHost=dbHost, dbDatabase=dbDatabase, @@ -282,34 +446,38 @@ class DatabaseConnector: # Set userId (default to empty string if None) self.userId = userId if userId is not None else "" - # Initialize database system first (creates database if needed) - self.connection = None + # No per-instance connection any more — real connections live in the + # shared `_PoolRegistry` pool. `_isCachedShared` is retained because + # `close(forceClose=False)` callers (interface __del__) still ask. self._isCachedShared = False - self.initDbSystem() - # No caching needed with proper database - PostgreSQL handles performance - - # Thread safety - self._lock = threading.Lock() - - # pgvector extension state + # pgvector extension state (cached per connector instance — cheap) self._vectorExtensionEnabled = False - # Initialize system table + # System table bootstrap: create database, system table, ensure metadata. self._systemTableName = "_system" + self.initDbSystem() self._initializeSystemTable() def initDbSystem(self): - """Initialize the database system - creates database and tables.""" - try: - # Create database if it doesn't exist - self._create_database_if_not_exists() + """Bootstrap the physical database and the `_system` metadata table. - # Create tables + Uses a short-lived autocommit connection (NOT the pool) because + `CREATE DATABASE` cannot run inside a transaction block. Also warms up + the pool by acquiring it for this `(host, db, port)` once. + """ + try: + self._create_database_if_not_exists() self._create_tables() - # Establish connection to the database - self._connect() + # Warm the pool so the first request doesn't pay for socket setup. + _PoolRegistry.getPool( + dbHost=self.dbHost, + dbDatabase=self.dbDatabase, + dbUser=self.dbUser, + dbPassword=self.dbPassword, + dbPort=self.dbPort, + ) logger.debug( "PostgreSQL database system initialized (db=%s, host=%s, port=%s)", @@ -319,10 +487,121 @@ class DatabaseConnector: logger.error(f"FATAL ERROR: Database system initialization failed: {e}") raise - def _create_database_if_not_exists(self): - """Create the database if it doesn't exist.""" + @property + def connection(self) -> "_ConnectionShim": + """Backward-compat shim — see `_ConnectionShim` docstring.""" + return _CONNECTION_SHIM + + def _ensure_connection(self) -> None: + """No-op for backward compatibility. + + Previously this method tested-or-reconnected the per-instance socket. + With pooling, the `ThreadedConnectionPool` re-establishes broken + connections lazily on the next `getconn()`. Kept as a no-op so that + legacy call sites continue to compile. + """ + return + + @contextmanager + def borrowCursor(self): + """Borrow a cursor for one short SQL block. + + Convenience wrapper around `borrowConn()` for the common pattern: + + with db.borrowCursor() as cursor: + cursor.execute(...) + rows = cursor.fetchall() + + Replaces the legacy `with db.connection.cursor() as cursor:` pattern. + Commit/rollback and pool return are handled automatically — callers + must NOT call `.commit()`/`.rollback()` on the cursor themselves. + """ + with self.borrowConn() as conn: + with conn.cursor() as cursor: + yield cursor + + @contextmanager + def borrowConn(self): + """Borrow a connection from the pool for the duration of the block. + + Pool-exhaustion semantics: `ThreadedConnectionPool.getconn()` raises + `psycopg2.pool.PoolError` immediately when the pool is at its `maxconn` + limit — it does NOT block. We wrap that with a bounded busy-wait so + bursty workloads (e.g. 50 concurrent route calls against a max=20 + pool) queue up instead of failing. If no connection becomes available + within `_BORROW_WAIT_TIMEOUT_S`, the `PoolError` is propagated so a + genuinely deadlocked pool surfaces as a real error. + + On normal exit the current transaction is committed (idempotent for + read-only queries — leaves the connection in a clean state for the + next borrower). On exception the transaction is rolled back and the + exception propagates. The connection is **always** returned to the + pool, even when commit/rollback fails — otherwise a single bad query + would leak a slot and eventually exhaust the pool. + """ + pool = _PoolRegistry.getPool( + dbHost=self.dbHost, + dbDatabase=self.dbDatabase, + dbUser=self.dbUser, + dbPassword=self.dbPassword, + dbPort=self.dbPort, + ) + conn = self._acquireConn(pool) + try: + yield conn + except Exception: + try: + conn.rollback() + except Exception: + pass + raise + else: + # Best-effort commit so the connection goes back to the pool with + # no in-flight transaction. Failure here is non-fatal — the pool + # will re-establish the socket on the next `getconn()` if needed. + try: + conn.commit() + except Exception: + try: + conn.rollback() + except Exception: + pass + finally: + try: + pool.putconn(conn) + except Exception as e: + logger.warning("Failed to return connection to pool: %s", e) + + @staticmethod + def _acquireConn(pool: psycopg2.pool.ThreadedConnectionPool): + """Get a connection from the pool, waiting up to `_BORROW_WAIT_TIMEOUT_S`. + + psycopg2's pool throws on exhaustion instead of queueing — this helper + polls with a short backoff so callers see queue semantics. + """ + deadline = time.monotonic() + _BORROW_WAIT_TIMEOUT_S + attempt = 0 + while True: + try: + return pool.getconn() + except psycopg2.pool.PoolError as e: + attempt += 1 + if time.monotonic() >= deadline: + logger.error( + "Connection pool exhausted after %.1fs wait (%d retries)", + _BORROW_WAIT_TIMEOUT_S, attempt, + ) + raise + time.sleep(_BORROW_WAIT_BACKOFF_S) + + def _create_database_if_not_exists(self): + """Create the database if it doesn't exist. + + Uses an autocommit connection on the `postgres` admin DB because + `CREATE DATABASE` cannot run inside a transaction block — so this + path intentionally does NOT use the pool. + """ try: - # Use the configured user for database creation conn = psycopg2.connect( host=self.dbHost, port=self.dbPort, @@ -330,23 +609,21 @@ class DatabaseConnector: user=self.dbUser, password=self.dbPassword, client_encoding="utf8", + connect_timeout=_CONNECT_TIMEOUT_S, ) conn.autocommit = True - - with conn.cursor() as cursor: - # Check if database exists - cursor.execute( - "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,) - ) - exists = cursor.fetchone() - - if not exists: - # Create database with proper quoting for names with hyphens - quoted_db_name = f'"{self.dbDatabase}"' - cursor.execute(f"CREATE DATABASE {quoted_db_name}") - logger.info(f"Created database: {self.dbDatabase}") - - conn.close() + try: + with conn.cursor() as cursor: + cursor.execute( + "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,) + ) + exists = cursor.fetchone() + if not exists: + quoted_db_name = f'"{self.dbDatabase}"' + cursor.execute(f"CREATE DATABASE {quoted_db_name}") + logger.info(f"Created database: {self.dbDatabase}") + finally: + conn.close() except Exception as e: logger.error(f"FATAL ERROR: Cannot create database: {e}") @@ -356,9 +633,12 @@ class DatabaseConnector: ) def _create_tables(self): - """Create only the system table - application tables are created by interfaces.""" + """Create the `_system` table. + + Uses a short-lived autocommit connection (not the pool) — runs exactly + once at connector creation. + """ try: - # Use the configured user for table creation conn = psycopg2.connect( host=self.dbHost, port=self.dbPort, @@ -366,23 +646,24 @@ class DatabaseConnector: user=self.dbUser, password=self.dbPassword, client_encoding="utf8", + connect_timeout=_CONNECT_TIMEOUT_S, ) conn.autocommit = True - - with conn.cursor() as cursor: - # Create only the system table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS _system ( - id SERIAL PRIMARY KEY, - table_name VARCHAR(255) UNIQUE NOT NULL, - initial_id VARCHAR(255) NOT NULL, - "sysCreatedAt" DOUBLE PRECISION, - "sysCreatedBy" VARCHAR(255), - "sysModifiedAt" DOUBLE PRECISION, - "sysModifiedBy" VARCHAR(255) - ) - """) - conn.close() + try: + with conn.cursor() as cursor: + cursor.execute(""" + CREATE TABLE IF NOT EXISTS _system ( + id SERIAL PRIMARY KEY, + table_name VARCHAR(255) UNIQUE NOT NULL, + initial_id VARCHAR(255) NOT NULL, + "sysCreatedAt" DOUBLE PRECISION, + "sysCreatedBy" VARCHAR(255), + "sysModifiedAt" DOUBLE PRECISION, + "sysModifiedBy" VARCHAR(255) + ) + """) + finally: + conn.close() except Exception as e: logger.error(f"FATAL ERROR: Cannot create system table: {e}") @@ -391,67 +672,26 @@ class DatabaseConnector: ) raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}") - def _connect(self): - """Establish connection to PostgreSQL database.""" - try: - # Use configured user for main connection with proper parameter handling - self.connection = psycopg2.connect( - host=self.dbHost, - port=self.dbPort, - database=self.dbDatabase, - user=self.dbUser, - password=self.dbPassword, - client_encoding="utf8", - cursor_factory=psycopg2.extras.RealDictCursor, - ) - self.connection.autocommit = False # Use transactions - except Exception as e: - logger.error(f"Failed to connect to PostgreSQL: {e}") - raise - - def _ensure_connection(self): - """Ensure database connection is alive, reconnect if necessary.""" - try: - if self.connection is None or self.connection.closed: - self._connect() - else: - # Test connection with a simple query - with self.connection.cursor() as cursor: - cursor.execute("SELECT 1") - except Exception as e: - logger.warning(f"Connection lost, reconnecting: {e}") - self._connect() - def _initializeSystemTable(self): """Initializes the system table if it doesn't exist yet.""" try: - # First ensure the system table exists self._ensureTableExists(SystemTable) - - with self.connection.cursor() as cursor: - # Check if system table has any data - cursor.execute('SELECT COUNT(*) FROM "_system"') - row = cursor.fetchone() - count = row["count"] if row else 0 - - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute('SELECT COUNT(*) FROM "_system"') + cursor.fetchone() # noqa: just verifies table is readable except Exception as e: logger.error(f"Error initializing system table: {e}") - self.connection.rollback() raise def _loadSystemTable(self) -> Dict[str, str]: """Loads the system table with the initial IDs.""" try: - with self.connection.cursor() as cursor: - cursor.execute('SELECT "table_name", "initial_id" FROM "_system"') - rows = cursor.fetchall() - - system_data = {} - for row in rows: - system_data[row["table_name"]] = row["initial_id"] - - return system_data + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute('SELECT "table_name", "initial_id" FROM "_system"') + rows = cursor.fetchall() + return {row["table_name"]: row["initial_id"] for row in rows} except Exception as e: logger.error(f"Error loading system table: {e}") return {} @@ -459,75 +699,65 @@ class DatabaseConnector: def _saveSystemTable(self, data: Dict[str, str]) -> bool: """Saves the system table with the initial IDs.""" try: - with self.connection.cursor() as cursor: - # Clear existing data - cursor.execute('DELETE FROM "_system"') - - # Insert new data - for table_name, initial_id in data.items(): - cursor.execute( - """ - INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt") - VALUES (%s, %s, %s) - """, - (table_name, initial_id, getUtcTimestamp()), - ) - - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute('DELETE FROM "_system"') + for table_name, initial_id in data.items(): + cursor.execute( + """ + INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt") + VALUES (%s, %s, %s) + """, + (table_name, initial_id, getUtcTimestamp()), + ) return True except Exception as e: logger.error(f"Error saving system table: {e}") - self.connection.rollback() return False def _ensureSystemTableExists(self) -> bool: """Ensures the system table exists, creates it if it doesn't.""" try: - self._ensure_connection() - - with self.connection.cursor() as cursor: - # Check if system table exists - cursor.execute( - "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", - (self._systemTableName,), - ) - exists = cursor.fetchone()["count"] > 0 - - if not exists: - # Create system table - cursor.execute(f""" - CREATE TABLE "{self._systemTableName}" ( - "table_name" VARCHAR(255) PRIMARY KEY, - "initial_id" VARCHAR(255), - "sysCreatedAt" DOUBLE PRECISION, - "sysCreatedBy" VARCHAR(255), - "sysModifiedAt" DOUBLE PRECISION, - "sysModifiedBy" VARCHAR(255) - ) - """) - logger.info("System table created successfully") - else: - # Check if we need to add missing columns to existing table + with self.borrowConn() as conn: + with conn.cursor() as cursor: cursor.execute( - """ - SELECT column_name FROM information_schema.columns - WHERE table_name = %s AND table_schema = 'public' - """, + "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", (self._systemTableName,), ) - existing_columns = [row["column_name"] for row in cursor.fetchall()] + exists = cursor.fetchone()["count"] > 0 - for sys_col, sys_sql in [ - ("sysCreatedAt", "DOUBLE PRECISION"), - ("sysCreatedBy", "VARCHAR(255)"), - ("sysModifiedAt", "DOUBLE PRECISION"), - ("sysModifiedBy", "VARCHAR(255)"), - ]: - if sys_col not in existing_columns: - cursor.execute( - f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}' + if not exists: + cursor.execute(f""" + CREATE TABLE "{self._systemTableName}" ( + "table_name" VARCHAR(255) PRIMARY KEY, + "initial_id" VARCHAR(255), + "sysCreatedAt" DOUBLE PRECISION, + "sysCreatedBy" VARCHAR(255), + "sysModifiedAt" DOUBLE PRECISION, + "sysModifiedBy" VARCHAR(255) ) + """) + logger.info("System table created successfully") + else: + cursor.execute( + """ + SELECT column_name FROM information_schema.columns + WHERE table_name = %s AND table_schema = 'public' + """, + (self._systemTableName,), + ) + existing_columns = [row["column_name"] for row in cursor.fetchall()] + for sys_col, sys_sql in [ + ("sysCreatedAt", "DOUBLE PRECISION"), + ("sysCreatedBy", "VARCHAR(255)"), + ("sysModifiedAt", "DOUBLE PRECISION"), + ("sysModifiedBy", "VARCHAR(255)"), + ]: + if sys_col not in existing_columns: + cursor.execute( + f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}' + ) return True except Exception as e: logger.error(f"Error ensuring system table exists: {e}") @@ -542,123 +772,113 @@ class DatabaseConnector: return self._ensureSystemTableExists() try: - self._ensure_connection() - - with self.connection.cursor() as cursor: - # Check if table exists by querying information_schema with case-insensitive search - cursor.execute( - """ - SELECT COUNT(*) FROM information_schema.tables - WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' - """, - (table,), - ) - exists = cursor.fetchone()["count"] > 0 - - if not exists: - # Create table from Pydantic model - self._create_table_from_model(cursor, table, model_class) - logger.info( - f"Created table '{table}' with columns from Pydantic model" + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute( + """ + SELECT COUNT(*) FROM information_schema.tables + WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' + """, + (table,), ) - else: - # Table exists: ensure all columns from model are present (simple additive migration) - try: - cursor.execute( - """ - SELECT column_name, data_type - FROM information_schema.columns - WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' - """, - (table,), + exists = cursor.fetchone()["count"] > 0 + + if not exists: + self._create_table_from_model(cursor, table, model_class) + logger.info( + f"Created table '{table}' with columns from Pydantic model" ) - existing_column_rows = cursor.fetchall() - existing_columns = { - row["column_name"] for row in existing_column_rows - } - existing_column_types = { - row["column_name"]: (row["data_type"] or "").lower() - for row in existing_column_rows - } + else: + # Table exists: ensure all columns from model are present (simple additive migration) + try: + cursor.execute( + """ + SELECT column_name, data_type + FROM information_schema.columns + WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' + """, + (table,), + ) + existing_column_rows = cursor.fetchall() + existing_columns = { + row["column_name"] for row in existing_column_rows + } + existing_column_types = { + row["column_name"]: (row["data_type"] or "").lower() + for row in existing_column_rows + } - # Desired columns based on model - model_fields = getModelFields(model_class) - desired_columns = set(["id"]) | set(model_fields.keys()) + model_fields = getModelFields(model_class) + desired_columns = set(["id"]) | set(model_fields.keys()) - # Add missing columns - for col in sorted(desired_columns - existing_columns): - # Determine SQL type - if col in ["id"]: - continue # primary key exists already - sql_type = model_fields.get(col) - if not sql_type: - sql_type = "TEXT" - try: - cursor.execute( - f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}' - ) - logger.info( - f"Added missing column '{col}' ({sql_type}) to '{table}'" - ) - except Exception as add_err: - logger.warning( - f"Could not add column '{col}' to '{table}': {add_err}" - ) - - # Column type migrations for existing tables. - # TEXT→DOUBLE PRECISION handles three value shapes: - # 1. NULL / empty string → NULL - # 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT - # 3. Plain numeric string like "3.14" → direct cast - _TEXT_TO_DOUBLE = ( - 'DOUBLE PRECISION USING CASE' - ' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL' - ' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\'' - ' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)' - ' ELSE NULLIF("{col}", \'\')::double precision' - ' END' - ) - _SAFE_TYPE_CHANGES = { - ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text", - ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE, - ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer", - ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')', - ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")', - ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')', - } - for col in sorted(desired_columns & existing_columns): - if col == "id": - continue - desired_sql = (model_fields.get(col) or "").upper() - currentType = existing_column_types.get(col, "") - migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql)) - if migration: - castExpr = migration.replace("{col}", col) + for col in sorted(desired_columns - existing_columns): + if col in ["id"]: + continue + sql_type = model_fields.get(col) + if not sql_type: + sql_type = "TEXT" try: - cursor.execute('SAVEPOINT col_migrate') cursor.execute( - f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}' + f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}' ) - cursor.execute('RELEASE SAVEPOINT col_migrate') logger.info( - f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'" + f"Added missing column '{col}' ({sql_type}) to '{table}'" ) - except Exception as alter_err: - cursor.execute('ROLLBACK TO SAVEPOINT col_migrate') + except Exception as add_err: logger.warning( - f"Could not migrate column '{col}' on '{table}': {alter_err}" + f"Could not add column '{col}' to '{table}': {add_err}" ) - except Exception as ensure_err: - logger.warning( - f"Could not ensure columns for existing table '{table}': {ensure_err}" - ) - self.connection.commit() + # Column type migrations for existing tables. + # TEXT→DOUBLE PRECISION handles three value shapes: + # 1. NULL / empty string → NULL + # 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT + # 3. Plain numeric string like "3.14" → direct cast + _TEXT_TO_DOUBLE = ( + 'DOUBLE PRECISION USING CASE' + ' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL' + ' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\'' + ' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)' + ' ELSE NULLIF("{col}", \'\')::double precision' + ' END' + ) + _SAFE_TYPE_CHANGES = { + ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text", + ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE, + ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer", + ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')', + ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")', + ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')', + } + for col in sorted(desired_columns & existing_columns): + if col == "id": + continue + desired_sql = (model_fields.get(col) or "").upper() + currentType = existing_column_types.get(col, "") + migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql)) + if migration: + castExpr = migration.replace("{col}", col) + try: + cursor.execute('SAVEPOINT col_migrate') + cursor.execute( + f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}' + ) + cursor.execute('RELEASE SAVEPOINT col_migrate') + logger.info( + f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'" + ) + except Exception as alter_err: + cursor.execute('ROLLBACK TO SAVEPOINT col_migrate') + logger.warning( + f"Could not migrate column '{col}' on '{table}': {alter_err}" + ) + except Exception as ensure_err: + logger.warning( + f"Could not ensure columns for existing table '{table}': {ensure_err}" + ) return True except Exception as e: logger.error(f"Error ensuring table {table} exists: {e}") - if hasattr(self, "connection") and self.connection: - self.connection.rollback() return False def _ensureVectorExtension(self) -> bool: @@ -666,17 +886,14 @@ class DatabaseConnector: if self._vectorExtensionEnabled: return True try: - self._ensure_connection() - with self.connection.cursor() as cursor: - cursor.execute("CREATE EXTENSION IF NOT EXISTS vector") - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute("CREATE EXTENSION IF NOT EXISTS vector") self._vectorExtensionEnabled = True logger.info("pgvector extension enabled") return True except Exception as e: logger.error(f"Failed to enable pgvector extension: {e}") - if hasattr(self, "connection") and self.connection: - self.connection.rollback() return False def _create_table_from_model(self, cursor, table: str, model_class: type) -> None: @@ -791,22 +1008,19 @@ class DatabaseConnector: if not self._ensureTableExists(model_class): return None - with self.connection.cursor() as cursor: - cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,)) - row = cursor.fetchone() - if not row: - return None + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,)) + row = cursor.fetchone() + if not row: + return None + record = dict(row) - # Convert row to dict and handle JSONB fields - record = dict(row) - fields = getModelFields(model_class) - - parseRecordFields(record, fields, f"record {recordId}") - - return record + fields = getModelFields(model_class) + parseRecordFields(record, fields, f"record {recordId}") + return record except Exception as e: logger.error(f"Error loading record {recordId} from table {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]: @@ -849,14 +1063,12 @@ class DatabaseConnector: if effective_user_id: record["sysModifiedBy"] = effective_user_id - with self.connection.cursor() as cursor: - self._save_record(cursor, table, recordId, record, model_class) - - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + self._save_record(cursor, table, recordId, record, model_class) return True except Exception as e: logger.error(f"Error saving record {recordId} to table {table}: {e}") - self.connection.rollback() return False def _loadTable(self, model_class: type) -> List[Dict[str, Any]]: @@ -870,33 +1082,32 @@ class DatabaseConnector: if not self._ensureTableExists(model_class): return [] - with self.connection.cursor() as cursor: - cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"') - records = [dict(row) for row in cursor.fetchall()] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"') + records = [dict(row) for row in cursor.fetchall()] - fields = getModelFields(model_class) - modelFields = model_class.model_fields - for record in records: - parseRecordFields(record, fields, f"table {table}") - # Set type-aware defaults for NULL JSONB fields - for fieldName, fieldType in fields.items(): - if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: - fieldInfo = modelFields.get(fieldName) - if fieldInfo: - fieldAnnotation = fieldInfo.annotation - if (fieldAnnotation == list or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is list)): - record[fieldName] = [] - elif (fieldAnnotation == dict or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is dict)): - record[fieldName] = {} + fields = getModelFields(model_class) + modelFields = model_class.model_fields + for record in records: + parseRecordFields(record, fields, f"table {table}") + for fieldName, fieldType in fields.items(): + if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: + fieldInfo = modelFields.get(fieldName) + if fieldInfo: + fieldAnnotation = fieldInfo.annotation + if (fieldAnnotation == list or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is list)): + record[fieldName] = [] + elif (fieldAnnotation == dict or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is dict)): + record[fieldName] = {} - return records + return records except Exception as e: logger.error(f"Error loading table {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def _registerInitialId(self, table: str, initialId: str) -> bool: @@ -969,28 +1180,20 @@ class DatabaseConnector: def getTables(self) -> List[str]: """Returns a list of all available tables.""" - tables = [] - + tables: List[str] = [] try: - # Ensure connection is alive - self._ensure_connection() - - if not self.connection or self.connection.closed: - logger.error("Database connection is not available") - return tables - - with self.connection.cursor() as cursor: - cursor.execute(""" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = 'public' - ORDER BY table_name - """) - rows = cursor.fetchall() - tables = [row["table_name"] for row in rows] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(""" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' + ORDER BY table_name + """) + rows = cursor.fetchall() + tables = [row["table_name"] for row in rows] except Exception as e: logger.error(f"Error reading the database {self.dbDatabase}: {e}") - return tables def getFields(self, model_class: type) -> List[str]: @@ -1060,43 +1263,42 @@ class DatabaseConnector: query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"' - with self.connection.cursor() as cursor: - cursor.execute(query, where_values) - records = [dict(row) for row in cursor.fetchall()] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(query, where_values) + records = [dict(row) for row in cursor.fetchall()] - fields = getModelFields(model_class) - modelFields = model_class.model_fields + fields = getModelFields(model_class) + modelFields = model_class.model_fields + for record in records: + parseRecordFields(record, fields, f"table {table}") + for fieldName, fieldType in fields.items(): + if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: + fieldInfo = modelFields.get(fieldName) + if fieldInfo: + fieldAnnotation = fieldInfo.annotation + if (fieldAnnotation == list or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is list)): + record[fieldName] = [] + elif (fieldAnnotation == dict or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is dict)): + record[fieldName] = {} + + if fieldFilter and isinstance(fieldFilter, list): + result = [] for record in records: - parseRecordFields(record, fields, f"table {table}") - for fieldName, fieldType in fields.items(): - if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: - fieldInfo = modelFields.get(fieldName) - if fieldInfo: - fieldAnnotation = fieldInfo.annotation - if (fieldAnnotation == list or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is list)): - record[fieldName] = [] - elif (fieldAnnotation == dict or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is dict)): - record[fieldName] = {} + filteredRecord = {} + for field in fieldFilter: + if field in record: + filteredRecord[field] = record[field] + result.append(filteredRecord) + return result - # If fieldFilter is available, reduce the fields - if fieldFilter and isinstance(fieldFilter, list): - result = [] - for record in records: - filteredRecord = {} - for field in fieldFilter: - if field in record: - filteredRecord[field] = record[field] - result.append(filteredRecord) - return result - - return records + return records except Exception as e: logger.error(f"Error loading records from table {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def _buildPaginationClauses( @@ -1281,35 +1483,36 @@ class DatabaseConnector: where_clause, order_clause, limit_clause, values, count_values = \ self._buildPaginationClauses(model_class, pagination, recordFilter) - with self.connection.cursor() as cursor: - countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}' - dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}' - cursor.execute(countSql, count_values) - totalItems = cursor.fetchone()["count"] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}' + dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}' + cursor.execute(countSql, count_values) + totalItems = cursor.fetchone()["count"] - cursor.execute(dataSql, values) - records = [dict(row) for row in cursor.fetchall()] + cursor.execute(dataSql, values) + records = [dict(row) for row in cursor.fetchall()] - fields = getModelFields(model_class) - modelFields = model_class.model_fields - for record in records: - parseRecordFields(record, fields, f"table {table}") - for fieldName, fieldType in fields.items(): - if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: - fieldInfo = modelFields.get(fieldName) - if fieldInfo: - fieldAnnotation = fieldInfo.annotation - if (fieldAnnotation == list or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is list)): - record[fieldName] = [] - elif (fieldAnnotation == dict or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is dict)): - record[fieldName] = {} + fields = getModelFields(model_class) + modelFields = model_class.model_fields + for record in records: + parseRecordFields(record, fields, f"table {table}") + for fieldName, fieldType in fields.items(): + if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: + fieldInfo = modelFields.get(fieldName) + if fieldInfo: + fieldAnnotation = fieldInfo.annotation + if (fieldAnnotation == list or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is list)): + record[fieldName] = [] + elif (fieldAnnotation == dict or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is dict)): + record[fieldName] = {} - if fieldFilter and isinstance(fieldFilter, list): - records = [{f: r[f] for f in fieldFilter if f in r} for r in records] + if fieldFilter and isinstance(fieldFilter, list): + records = [{f: r[f] for f in fieldFilter if f in r} for r in records] from modules.routes.routeHelpers import enrichRowsWithFkLabels enrichRowsWithFkLabels(records, model_class) @@ -1320,7 +1523,6 @@ class DatabaseConnector: return {"items": records, "totalItems": totalItems, "totalPages": totalPages} except Exception as e: logger.error(f"Error in getRecordsetPaginated for table {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def getDistinctColumnValues( @@ -1365,25 +1567,24 @@ class DatabaseConnector: else: sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val' - with self.connection.cursor() as cursor: - cursor.execute(sql, values) - result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(sql, values) + result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()] - if includeEmpty: - emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\'' - if where_clause: - emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1' - else: - emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1' - with self.connection.cursor() as cursor: - cursor.execute(emptySql, values) - if cursor.fetchone(): - result.append(None) + if includeEmpty: + emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\'' + if where_clause: + emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1' + else: + emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1' + cursor.execute(emptySql, values) + if cursor.fetchone(): + result.append(None) return result except Exception as e: logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def recordCreate( @@ -1463,33 +1664,33 @@ class DatabaseConnector: if not self._ensureTableExists(model_class): return False - with self.connection.cursor() as cursor: - # Check if record exists - cursor.execute( - f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,) - ) - if not cursor.fetchone(): - return False + # `getInitialId` opens its own borrow; do it BEFORE we acquire a + # connection ourselves so we don't pin two slots concurrently. + initialId = self.getInitialId(model_class) - # Check if it's an initial record - initialId = self.getInitialId(model_class) - if initialId is not None and initialId == recordId: - self._removeInitialId(table) - logger.info( - f"Initial ID {recordId} for table {table} has been removed from the system table" + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute( + f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,) ) + if not cursor.fetchone(): + return False - # Delete the record - cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,)) + if initialId is not None and initialId == recordId: + # `_removeInitialId` borrows its own conn — done outside + # this block on purpose to avoid nested borrows. + pass + cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,)) - # No cache to update - database handles consistency - - self.connection.commit() + if initialId is not None and initialId == recordId: + self._removeInitialId(table) + logger.info( + f"Initial ID {recordId} for table {table} has been removed from the system table" + ) return True except Exception as e: logger.error(f"Error deleting record {recordId} from table {table}: {e}") - self.connection.rollback() return False def recordCreateBulk( @@ -1559,16 +1760,11 @@ class DatabaseConnector: ) try: - self._ensure_connection() - with self.connection.cursor() as cursor: - psycopg2.extras.execute_values(cursor, sql, rows, page_size=500) - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + psycopg2.extras.execute_values(cursor, sql, rows, page_size=500) except Exception as e: logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}") - try: - self.connection.rollback() - except Exception: - pass raise if self.getInitialId(model_class) is None and normalised: @@ -1649,26 +1845,21 @@ class DatabaseConnector: initialId = self.getInitialId(model_class) try: - self._ensure_connection() - with self.connection.cursor() as cursor: - if initialId is not None: - cursor.execute( - f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql, - [initialId, *params], - ) - initialIsAffected = cursor.fetchone() is not None - else: - initialIsAffected = False + with self.borrowConn() as conn: + with conn.cursor() as cursor: + if initialId is not None: + cursor.execute( + f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql, + [initialId, *params], + ) + initialIsAffected = cursor.fetchone() is not None + else: + initialIsAffected = False - cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params) - deleted = cursor.rowcount or 0 - self.connection.commit() + cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params) + deleted = cursor.rowcount or 0 except Exception as e: logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}") - try: - self.connection.rollback() - except Exception: - pass raise if deleted and initialIsAffected: @@ -1751,39 +1942,30 @@ class DatabaseConnector: ) params = [vectorStr] + whereValues + [vectorStr, limit] - with self.connection.cursor() as cursor: - cursor.execute(query, params) - records = [dict(row) for row in cursor.fetchall()] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(query, params) + records = [dict(row) for row in cursor.fetchall()] - fields = getModelFields(modelClass) - for record in records: - parseRecordFields(record, fields, f"semanticSearch {table}") - - return records + fields = getModelFields(modelClass) + for record in records: + parseRecordFields(record, fields, f"semanticSearch {table}") + return records except Exception as e: logger.error(f"Error in semantic search on {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def close(self, forceClose: bool = False): - """Close the database connection. - - Shared cached connectors are intentionally kept open unless forceClose=True. - This prevents accidental shutdown from interface __del__ methods while - other requests are still using the same cached connector instance. + """No-op for backward compatibility. + + Connections are now owned by the `_PoolRegistry` pool and live for the + process lifetime. Pool shutdown happens centrally via `closeAllPools()` + from the FastAPI lifespan hook — never from a connector instance. + Interface `__del__` paths used to call `close()` to release a per- + connector socket; with pooling there is nothing to close here. """ - if self._isCachedShared and not forceClose: - return - if ( - hasattr(self, "connection") - and self.connection - and not self.connection.closed - ): - self.connection.close() + return def __del__(self): - """Cleanup method to close connection.""" - try: - self.close() - except Exception: - pass + """Cleanup hook (intentionally no-op — see `close`).""" + return diff --git a/modules/features/realEstate/interfaceFeatureRealEstate.py b/modules/features/realEstate/interfaceFeatureRealEstate.py index 1fbaf06f..0637d0e9 100644 --- a/modules/features/realEstate/interfaceFeatureRealEstate.py +++ b/modules/features/realEstate/interfaceFeatureRealEstate.py @@ -342,7 +342,7 @@ class RealEstateObjects: # If no exact match, try case-insensitive search via SQL query # This handles cases where the name might have different casing self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: cursor.execute( 'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1', (name,) @@ -375,7 +375,7 @@ class RealEstateObjects: # Try case-insensitive search self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: cursor.execute( 'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1', (name,) @@ -408,7 +408,7 @@ class RealEstateObjects: # Try case-insensitive search self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: cursor.execute( 'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1', (name,) @@ -840,7 +840,7 @@ class RealEstateObjects: # Ensure connection is alive self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: # Execute query if parameters: # Use parameterized query for safety diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index 25f022af..273583d9 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -1659,7 +1659,7 @@ class BillingObjects: try: appInterface = getAppInterface(self.currentUser) appInterface.db._ensure_connection() - with appInterface.db.connection.cursor() as cur: + with appInterface.db.borrowCursor() as cur: if appInterface.db._ensureTableExists(UserInDB): cur.execute( 'SELECT "id" FROM "UserInDB" WHERE ' @@ -1780,7 +1780,7 @@ class BillingObjects: try: self.db._ensure_connection() - with self.db.connection.cursor() as cur: + with self.db.borrowCursor() as cur: countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' cur.execute(countSql, whereValues) totalItems = cur.fetchone()["count"] @@ -1797,10 +1797,7 @@ class BillingObjects: except Exception as e: logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True) - try: - self.db.connection.rollback() - except Exception: - pass + # Rollback is handled by `borrowCursor()` context manager on exit. return {"items": [], "totalItems": 0, "totalPages": 0} def _buildScopeFilter( @@ -1872,7 +1869,7 @@ class BillingObjects: result: Dict[str, Any] = {} - with self.db.connection.cursor() as cur: + with self.db.borrowCursor() as cur: # 1) Totals cur.execute( f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}', @@ -1947,17 +1944,12 @@ class BillingObjects: }) result["timeSeries"] = timeSeries - self.db.connection.commit() - + # Commit/rollback are handled by `borrowCursor()` context manager. result["_allAccounts"] = allAccounts return result except Exception as e: logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True) - try: - self.db.connection.rollback() - except Exception: - pass return self._emptyStats() @staticmethod diff --git a/modules/interfaces/interfaceDbKnowledge.py b/modules/interfaces/interfaceDbKnowledge.py index 31a5af61..d7a445bd 100644 --- a/modules/interfaces/interfaceDbKnowledge.py +++ b/modules/interfaces/interfaceDbKnowledge.py @@ -228,6 +228,22 @@ class KnowledgeObjects: """Get all ContentChunks for a file.""" return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId}) + def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]: + """Return a {fileId: chunkCount} mapping for the given file IDs. + + One aggregate query instead of N round trips. Used by RAG inventory + to display real chunk counts per DataSource without loading the + embedding vectors. Missing file IDs map to 0 in the caller's logic. + """ + if not fileIds: + return {} + if not self.db._ensureTableExists(ContentChunk): + return {} + sql = 'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"' + with self.db.borrowCursor() as cursor: + cursor.execute(sql, (list(fileIds),)) + return {row["fileId"]: int(row["cnt"]) for row in cursor.fetchall()} + def deleteContentChunks(self, fileId: str) -> int: """Delete all ContentChunks for a file. Returns count of deleted chunks.""" chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId}) diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py index 6a3c27b5..4dc8a206 100644 --- a/modules/interfaces/interfaceDbManagement.py +++ b/modules/interfaces/interfaceDbManagement.py @@ -1221,22 +1221,17 @@ class ComponentObjects: for item in fileRows ] - # Single transaction: delete FileData, FileItem, then FileFolder (children first) - self.db._ensure_connection() - try: - with self.db.connection.cursor() as cursor: - if fileIds: - cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,)) - cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,)) - orderedIds = list(folderIds) - orderedIds.remove(folderId) - orderedIds.append(folderId) - if orderedIds: - cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,)) - self.db.connection.commit() - except Exception: - self.db.connection.rollback() - raise + # Single transaction: delete FileData, FileItem, then FileFolder (children first). + # Commit/rollback are handled by `borrowCursor()` on exit. + with self.db.borrowCursor() as cursor: + if fileIds: + cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,)) + cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,)) + orderedIds = list(folderIds) + orderedIds.remove(folderId) + orderedIds.append(folderId) + if orderedIds: + cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,)) return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)} @@ -1507,7 +1502,7 @@ class ComponentObjects: try: self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: cursor.execute( 'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)', (uniqueIds,), @@ -1526,11 +1521,10 @@ class ComponentObjects: cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,)) deletedFiles = cursor.rowcount - self.db.connection.commit() + # Commit/rollback are handled by `borrowCursor()` context manager. return {"deletedFiles": deletedFiles} except Exception as e: logger.error(f"Error deleting files in batch: {e}") - self.db.connection.rollback() raise FileDeletionError(f"Error deleting files in batch: {str(e)}") def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]: diff --git a/modules/interfaces/interfaceRbac.py b/modules/interfaces/interfaceRbac.py index e41485e0..948609ef 100644 --- a/modules/interfaces/interfaceRbac.py +++ b/modules/interfaces/interfaceRbac.py @@ -374,7 +374,7 @@ def getRecordsetWithRBAC( query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}' - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: cursor.execute(query, whereValues) records = [dict(row) for row in cursor.fetchall()] @@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC( offset = (pagination.page - 1) * pagination.pageSize limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}" - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' cursor.execute(countSql, countValues) totalItems = cursor.fetchone()["count"] @@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC( sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val' - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: cursor.execute(sql, whereValues) result = [row["val"] for row in cursor.fetchall()] @@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC( emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1' else: emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1' - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: cursor.execute(emptySql, whereValues) if cursor.fetchone(): result.append(None) @@ -967,7 +967,7 @@ def buildRbacWhereClause( # Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate if table == "UserInDB": try: - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: # Get all user IDs that are members of the current mandate cursor.execute( 'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true', @@ -994,7 +994,7 @@ def buildRbacWhereClause( # For UserConnection: Filter via UserMandate junction table elif table == "UserConnection": try: - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: # Get all user IDs that are members of the current mandate cursor.execute( 'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true', diff --git a/modules/routes/routeHelpers.py b/modules/routes/routeHelpers.py index f1d88e31..bb1386af 100644 --- a/modules/routes/routeHelpers.py +++ b/modules/routes/routeHelpers.py @@ -305,7 +305,7 @@ def handleIdsMode( sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"' - with db.connection.cursor() as cursor: + with db.borrowCursor() as cursor: cursor.execute(sql, values) return JSONResponse(content=[row["val"] for row in cursor.fetchall()]) except Exception as e: diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py index 074b5b85..7c426d77 100644 --- a/modules/routes/routeRagInventory.py +++ b/modules/routes/routeRagInventory.py @@ -25,6 +25,18 @@ router = APIRouter( def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]: + """Build per-connection RAG inventory rows. + + Each DataSource row exposes BOTH numbers because they mean different things: + * `fileCount` — distinct files indexed (== `FileContentIndex` rows) + * `chunkCount` — embedding-sized text fragments (== `ContentChunk` rows, + max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval + actually hits) + + A single PDF typically yields 1 file × 5–100 chunks; legacy UI labelled + `len(FileContentIndex)` as "chunks" which was off by 1–2 orders of + magnitude and misleading. + """ from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelKnowledge import FileContentIndex @@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId}) - connChunkTotal = len(connIndexRows) + connFileTotal = len(connIndexRows) + # Map fileId → real chunk count via 1 aggregate query (cheap even for + # connections with thousands of files; we never load the vector body). + fileIds = [ + (idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")) + for idx in connIndexRows + ] + fileIds = [fid for fid in fileIds if fid] + chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {} + connChunkTotal = sum(chunkCountByFile.values()) + + filesByDs: Dict[str, int] = {} chunksByDs: Dict[str, int] = {} - unassigned = 0 + unassignedFiles = 0 + unassignedChunks = 0 for idx in connIndexRows: + fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "") + chunkCnt = chunkCountByFile.get(fileId, 0) struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {} ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {} prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {} dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else "" if dsIdRef: - chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1 + filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1 + chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt else: - unassigned += 1 + unassignedFiles += 1 + unassignedChunks += chunkCnt seen: Dict[str, bool] = {} dsItems = [] @@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False), "neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False), "lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None), + "fileCount": filesByDs.get(dsId, 0), "chunkCount": chunksByDs.get(dsId, 0), }) - if unassigned > 0 and len(dsItems) > 0: - perDs = unassigned // len(dsItems) - remainder = unassigned % len(dsItems) + # Spread orphan files (provenance lost) evenly so totals match. + if unassignedFiles > 0 and len(dsItems) > 0: + perFile = unassignedFiles // len(dsItems) + remFile = unassignedFiles % len(dsItems) + perChunk = unassignedChunks // len(dsItems) + remChunk = unassignedChunks % len(dsItems) for i, item in enumerate(dsItems): - item["chunkCount"] += perDs + (1 if i < remainder else 0) + item["fileCount"] += perFile + (1 if i < remFile else 0) + item["chunkCount"] += perChunk + (1 if i < remChunk else 0) # Pull a wider window than the previous 5 so the "last successful # sync" is found even if a connection has many recent jobs queued. @@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "skippedPolicy": result.get("skippedPolicy", 0), "failed": result.get("failed", 0), "durationMs": result.get("durationMs", 0), + # Surface limit-stop reason so the UI can warn the user + # that the index is provably incomplete (and which budget + # to raise). None means the walker finished naturally. + "stoppedAtLimit": result.get("stoppedAtLimit"), + "limits": result.get("limits") or {}, + "bytesProcessed": result.get("bytesProcessed", 0), } if lastError and lastSuccess: break @@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False), "preferences": getattr(conn, "knowledgePreferences", None) or {}, "dataSources": dsItems, + "totalFiles": connFileTotal, "totalChunks": connChunkTotal, "runningJobs": runningJobs, "lastError": lastError, @@ -139,8 +179,9 @@ def _getInventoryMe( items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) totalChunks = sum(c.get("totalChunks", 0) for c in items) + totalFiles = sum(c.get("totalFiles", 0) for c in items) - return {"connections": items, "totals": {"chunks": totalChunks}} + return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} except Exception as e: logger.error("Error in RAG inventory /me: %s", e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @@ -170,9 +211,10 @@ def _getInventoryMandate( items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) totalChunks = sum(c.get("totalChunks", 0) for c in items) + totalFiles = sum(c.get("totalFiles", 0) for c in items) totalBytes = aggregateMandateRagTotalBytes(mandateId) - return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}} + return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}} except HTTPException: raise except Exception as e: @@ -202,8 +244,9 @@ def _getInventoryPlatform( items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) totalChunks = sum(c.get("totalChunks", 0) for c in items) + totalFiles = sum(c.get("totalFiles", 0) for c in items) - return {"connections": items, "totals": {"chunks": totalChunks}} + return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} except HTTPException: raise except Exception as e: diff --git a/modules/routes/routeWorkflowDashboard.py b/modules/routes/routeWorkflowDashboard.py index d83ce1b2..85b372a1 100644 --- a/modules/routes/routeWorkflowDashboard.py +++ b/modules/routes/routeWorkflowDashboard.py @@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s) GROUP BY "workflowId" """ out: dict = {} - with db.connection.cursor() as cursor: + with db.borrowCursor() as cursor: cursor.execute(sql, (workflowIds,)) for row in cursor.fetchall(): r = dict(row) @@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated( dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}" db._ensure_connection() - with db.connection.cursor() as cursor: + with db.borrowCursor() as cursor: cursor.execute(countSql, countValues) totalItems = int(cursor.fetchone()["cnt"]) diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py index 4fbea490..bdb3d23b 100644 --- a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py +++ b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py @@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300 def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str): - """Reuse a pooled DB connector for the given feature database.""" + """Reuse a pooled DB connector for the given feature database. + + The underlying psycopg2 connections live in the central pool + (`_PoolRegistry`) and are recreated on demand if they go stale; we just + need to keep the lightweight connector wrapper around. + """ if featureDbName in _featureDbConnPool: - conn = _featureDbConnPool[featureDbName] - try: - if conn.connection and not conn.connection.closed: - return conn - except Exception as e: - logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}") - _featureDbConnPool.pop(featureDbName, None) + return _featureDbConnPool[featureDbName] from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.shared.configuration import APP_CONFIG diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py index 8bfa2628..959e42c9 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py @@ -68,6 +68,9 @@ class ClickupBootstrapResult: workspaces: int = 0 lists: int = 0 errors: List[str] = field(default_factory=list) + # First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None. + # Drives the same UI banner as the file-walker bootstraps. + stoppedAtLimit: Optional[str] = None def _syntheticTaskId(connectionId: str, taskId: str) -> str: @@ -225,6 +228,7 @@ async def bootstrapClickup( cancelled = False for ds in dataSources: if result.indexed + result.skippedDuplicate >= limits.maxTasks: + _recordLimitStop(result, "maxTasks", "dataSource", limits) break if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): cancelled = True @@ -243,8 +247,11 @@ async def bootstrapClickup( clickupScope=limits.clickupScope, ) + if len(teams) > dsLimits.maxWorkspaces: + _recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False) for team in teams[:dsLimits.maxWorkspaces]: if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks: + _recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits) break teamId = str(team.get("id", "") or "") if not teamId: @@ -351,6 +358,7 @@ async def _walkTeam( for lst in listsCollected: if result.indexed + result.skippedDuplicate >= limits.maxTasks: + _recordLimitStop(result, "maxTasks", f"team={teamId}", limits) return if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): return @@ -407,6 +415,7 @@ async def _walkList( for task in tasks: if result.indexed + result.skippedDuplicate >= limits.maxTasks: + _recordLimitStop(result, "maxTasks", f"list={listId}", limits) return if not _isRecent(task.get("date_updated"), limits.maxAgeDays): result.skippedPolicy += 1 @@ -529,13 +538,37 @@ async def _ingestTask( ) +def _recordLimitStop( + result: ClickupBootstrapResult, + limitName: str, + where: str, + limits: ClickupBootstrapLimits, + *, + hard: bool = True, +) -> None: + """See subConnectorSyncSharepoint._recordLimitStop for semantics.""" + if hard or result.stoppedAtLimit is None: + result.stoppedAtLimit = limitName + budgetMap = { + "maxTasks": limits.maxTasks, + "maxWorkspaces": limits.maxWorkspaces, + "maxListsPerWorkspace": limits.maxListsPerWorkspace, + } + logger.warning( + "clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).", + limitName, budgetMap.get(limitName), where, + result.indexed, result.skippedDuplicate, + ) + + def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]: durationMs = int((time.time() - startMs) * 1000) logger.info( - "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d", + "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s", connectionId, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, result.workspaces, result.lists, durationMs, + result.stoppedAtLimit or "none", extra={ "event": "ingestion.connection.bootstrap.done", "part": "clickup", @@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: "workspaces": result.workspaces, "lists": result.lists, "durationMs": durationMs, + "stoppedAtLimit": result.stoppedAtLimit, }, ) return { @@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: "lists": result.lists, "durationMs": durationMs, "errors": result.errors[:20], + "stoppedAtLimit": result.stoppedAtLimit, + "limits": { + "maxTasks": MAX_TASKS_DEFAULT, + "maxWorkspaces": MAX_WORKSPACES_DEFAULT, + "maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT, + "maxAgeDays": MAX_AGE_DAYS_DEFAULT, + }, } diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py index 5dd1bd8b..e27abacb 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py @@ -61,6 +61,8 @@ class GdriveBootstrapResult: failed: int = 0 bytesProcessed: int = 0 errors: List[str] = field(default_factory=list) + # See SharepointBootstrapResult.stoppedAtLimit — same semantics. + stoppedAtLimit: Optional[str] = None def _syntheticFileId(connectionId: str, externalItemId: str) -> str: @@ -265,8 +267,10 @@ async def _walkFolder( for entry in entries: if result.indexed + result.skippedDuplicate >= limits.maxItems: + _recordLimitStop(result, "maxItems", folderPath, limits) return if result.bytesProcessed >= limits.maxBytes: + _recordLimitStop(result, "maxBytes", folderPath, limits) return if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): return @@ -276,6 +280,9 @@ async def _walkFolder( mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType") if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME: + if depth + 1 > limits.maxDepth: + _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False) + continue await _walkFolder( adapter=adapter, knowledgeService=knowledgeService, @@ -298,6 +305,7 @@ async def _walkFolder( continue size = int(getattr(entry, "size", 0) or 0) if size and size > limits.maxFileSize: + _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False) result.skippedPolicy += 1 continue modifiedTime = metadata.get("modifiedTime") @@ -470,13 +478,38 @@ async def _ingestOne( await asyncio.sleep(0) +def _recordLimitStop( + result: GdriveBootstrapResult, + limitName: str, + where: str, + limits: GdriveBootstrapLimits, + *, + hard: bool = True, +) -> None: + """See subConnectorSyncSharepoint._recordLimitStop for semantics.""" + if hard or result.stoppedAtLimit is None: + result.stoppedAtLimit = limitName + budgetMap = { + "maxItems": limits.maxItems, + "maxBytes": limits.maxBytes, + "maxDepth": limits.maxDepth, + "maxFileSize": limits.maxFileSize, + } + logger.warning( + "gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).", + limitName, budgetMap.get(limitName), where, + result.indexed, result.bytesProcessed, + ) + + def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]: durationMs = int((time.time() - startMs) * 1000) logger.info( - "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d", + "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s", connectionId, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, result.bytesProcessed, durationMs, + result.stoppedAtLimit or "none", extra={ "event": "ingestion.connection.bootstrap.done", "part": "gdrive", @@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f "failed": result.failed, "bytes": result.bytesProcessed, "durationMs": durationMs, + "stoppedAtLimit": result.stoppedAtLimit, }, ) return { @@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f "bytesProcessed": result.bytesProcessed, "durationMs": durationMs, "errors": result.errors[:20], + "stoppedAtLimit": result.stoppedAtLimit, + "limits": { + "maxItems": MAX_ITEMS_DEFAULT, + "maxBytes": MAX_BYTES_DEFAULT, + "maxFileSize": MAX_FILE_SIZE_DEFAULT, + "maxDepth": MAX_DEPTH_DEFAULT, + }, } diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py index e656abe8..dcf19e39 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py @@ -53,6 +53,8 @@ class KdriveBootstrapResult: failed: int = 0 bytesProcessed: int = 0 errors: List[str] = field(default_factory=list) + # See SharepointBootstrapResult.stoppedAtLimit — same semantics. + stoppedAtLimit: Optional[str] = None def _syntheticFileId(connectionId: str, externalItemId: str) -> str: @@ -232,14 +234,19 @@ async def _walkFolder( for entry in entries: if result.indexed + result.skippedDuplicate >= limits.maxItems: + _recordLimitStop(result, "maxItems", folderPath, limits) return if result.bytesProcessed >= limits.maxBytes: + _recordLimitStop(result, "maxBytes", folderPath, limits) return if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): return entryPath = getattr(entry, "path", "") or "" if getattr(entry, "isFolder", False): + if depth + 1 > limits.maxDepth: + _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False) + continue await _walkFolder( adapter=adapter, knowledgeService=knowledgeService, @@ -262,6 +269,7 @@ async def _walkFolder( continue size = int(getattr(entry, "size", 0) or 0) if size and size > limits.maxFileSize: + _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False) result.skippedPolicy += 1 continue @@ -415,17 +423,42 @@ async def _ingestOne( await asyncio.sleep(0) +def _recordLimitStop( + result: KdriveBootstrapResult, + limitName: str, + where: str, + limits: KdriveBootstrapLimits, + *, + hard: bool = True, +) -> None: + """See subConnectorSyncSharepoint._recordLimitStop for semantics.""" + if hard or result.stoppedAtLimit is None: + result.stoppedAtLimit = limitName + budgetMap = { + "maxItems": limits.maxItems, + "maxBytes": limits.maxBytes, + "maxDepth": limits.maxDepth, + "maxFileSize": limits.maxFileSize, + } + logger.warning( + "kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).", + limitName, budgetMap.get(limitName), where, + result.indexed, result.bytesProcessed, + ) + + def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]: durationMs = int((time.time() - startMs) * 1000) logger.info( - "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d", + "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s", connectionId, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, - durationMs, + durationMs, result.stoppedAtLimit or "none", extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive", "connectionId": connectionId, "indexed": result.indexed, "skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy, - "failed": result.failed, "durationMs": durationMs}, + "failed": result.failed, "durationMs": durationMs, + "stoppedAtLimit": result.stoppedAtLimit}, ) return { "connectionId": result.connectionId, @@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f "bytesProcessed": result.bytesProcessed, "durationMs": durationMs, "errors": result.errors[:20], + "stoppedAtLimit": result.stoppedAtLimit, + "limits": { + "maxItems": MAX_ITEMS_DEFAULT, + "maxBytes": MAX_BYTES_DEFAULT, + "maxFileSize": MAX_FILE_SIZE_DEFAULT, + "maxDepth": MAX_DEPTH_DEFAULT, + }, } diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py index 892e41ba..e06fd36b 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py @@ -59,6 +59,10 @@ class SharepointBootstrapResult: failed: int = 0 bytesProcessed: int = 0 errors: List[str] = field(default_factory=list) + # First budget that hit zero; None means the walk completed naturally. + # Surfaces in the bootstrap result so the RAG inventory UI can warn the + # user that the corpus is incomplete and tell them which knob to turn. + stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None def _syntheticFileId(connectionId: str, externalItemId: str) -> str: @@ -259,14 +263,22 @@ async def _walkFolder( for entry in entries: if result.indexed + result.skippedDuplicate >= limits.maxItems: + _recordLimitStop(result, "maxItems", folderPath, limits) return if result.bytesProcessed >= limits.maxBytes: + _recordLimitStop(result, "maxBytes", folderPath, limits) return if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): return entryPath = getattr(entry, "path", "") or "" if getattr(entry, "isFolder", False): + if depth + 1 > limits.maxDepth: + # We stop descending here but keep walking siblings. + # Record once per bootstrap so the UI shows "maxDepth" even + # if other budgets aren't exhausted yet. + _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False) + continue await _walkFolder( adapter=adapter, knowledgeService=knowledgeService, @@ -289,6 +301,7 @@ async def _walkFolder( continue size = int(getattr(entry, "size", 0) or 0) if size and size > limits.maxFileSize: + _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False) result.skippedPolicy += 1 continue @@ -443,13 +456,44 @@ async def _ingestOne( await asyncio.sleep(0) +def _recordLimitStop( + result: SharepointBootstrapResult, + limitName: str, + where: str, + limits: SharepointBootstrapLimits, + *, + hard: bool = True, +) -> None: + """Mark the FIRST limit that bit. Soft hits (per-file maxFileSize, per-folder + maxDepth) only record when no hard limit has yet stopped the run, so the UI + surfaces the most important reason. + + Hard limits (maxItems / maxBytes) ALWAYS overwrite a previously recorded + soft limit — once a hard cap is hit, the corpus is provably incomplete. + """ + if hard or result.stoppedAtLimit is None: + result.stoppedAtLimit = limitName + budgetMap = { + "maxItems": limits.maxItems, + "maxBytes": limits.maxBytes, + "maxDepth": limits.maxDepth, + "maxFileSize": limits.maxFileSize, + } + logger.warning( + "sharepoint walker hit %s=%s at %s — partial index " + "(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.", + limitName, budgetMap.get(limitName), where, + result.indexed, result.bytesProcessed, + ) + + def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]: durationMs = int((time.time() - startMs) * 1000) logger.info( - "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d", + "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s", connectionId, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, - durationMs, + durationMs, result.stoppedAtLimit or "none", extra={ "event": "ingestion.connection.bootstrap.done", "part": "sharepoint", @@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM "skippedPolicy": result.skippedPolicy, "failed": result.failed, "durationMs": durationMs, + "stoppedAtLimit": result.stoppedAtLimit, }, ) return { @@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM "bytesProcessed": result.bytesProcessed, "durationMs": durationMs, "errors": result.errors[:20], + "stoppedAtLimit": result.stoppedAtLimit, + "limits": { + "maxItems": MAX_ITEMS_DEFAULT, + "maxBytes": MAX_BYTES_DEFAULT, + "maxFileSize": MAX_FILE_SIZE_DEFAULT, + "maxDepth": MAX_DEPTH_DEFAULT, + }, } diff --git a/modules/shared/configuration.py b/modules/shared/configuration.py index 721ce448..15646962 100644 --- a/modules/shared/configuration.py +++ b/modules/shared/configuration.py @@ -12,7 +12,8 @@ import logging import json import base64 import time -from typing import Any, Dict, Optional +import threading +from typing import Any, Dict, Optional, Tuple from pathlib import Path from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes @@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown # Structure: {user_id: {key_name: [timestamps]}} _decryption_attempts = {} +# Process-wide plaintext cache for decrypted secrets. +# Key: the encrypted ciphertext (which already includes env prefix). +# Value: (expiresAtMonotonic, plaintext). +# TTL is short enough that key rotation propagates quickly, long enough that +# hot DB-init paths (every API call building a connector) don't blow the +# decryption rate limit. 60s is a deliberate compromise. +_DECRYPTION_CACHE_TTL_S = 60.0 +_decryption_cache: Dict[str, Tuple[float, str]] = {} +_decryption_cache_lock = threading.Lock() + def _getMasterKey(envType: str = None) -> bytes: """ Get the master key for the specified environment. @@ -486,25 +497,43 @@ def encryptValue(value: str, envType: str = None, userId: str = "system", keyNam def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "unknown") -> str: """ Decrypt a value using the master key for the current environment. - + + A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted + first. The 10/sec rate-limit on cache misses still protects against + brute-force attacks; cache HITS bypass it because they are not actual + cryptographic operations — they just return the result of an earlier + successful decrypt. Without this cache, hot paths like + `mainBackgroundJobService._getDb()` (called per RAG inventory poll AND + per walker DB call) trigger the rate limit and surface as + "Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'" + ERRORs in the RAG inventory UI route. + Args: encryptedValue: The encrypted value with prefix userId: The user ID making the request (default: "system") keyName: The name of the key being decrypted (default: "unknown") - + Returns: str: The decrypted plain text value - + Raises: ValueError: If decryption fails """ if not _isEncryptedValue(encryptedValue): return encryptedValue # Return as-is if not encrypted - - # Check rate limiting (10 per second per user per key) + + # Cache lookup BEFORE the rate-limit check: a cache hit is not a new + # cryptographic operation and must not be throttled. + now = time.monotonic() + with _decryption_cache_lock: + cached = _decryption_cache.get(encryptedValue) + if cached is not None and cached[0] > now: + return cached[1] + + # Cache miss → real decrypt → apply rate limit. if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10): raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)") - + try: # Extract environment type from prefix if encryptedValue.startswith('DEV_ENC:'): @@ -536,7 +565,7 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un encryptedBytes = base64.urlsafe_b64decode(encryptedPart.encode('utf-8')) decryptedBytes = fernet.decrypt(encryptedBytes) decryptedValue = decryptedBytes.decode('utf-8') - + # Log audit event for decryption try: from modules.shared.auditLogger import audit_logger @@ -549,11 +578,25 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un except Exception: # Don't fail if audit logging fails pass - + + # Populate cache so subsequent reads of the same ciphertext don't + # re-decrypt (and don't consume rate-limit budget). + with _decryption_cache_lock: + _decryption_cache[encryptedValue] = ( + time.monotonic() + _DECRYPTION_CACHE_TTL_S, + decryptedValue, + ) + return decryptedValue - + except Exception as e: raise ValueError(f"Decryption failed: {e}") + +def clearDecryptionCache() -> None: + """Drop all cached plaintext secrets. Call after key rotation or in tests.""" + with _decryption_cache_lock: + _decryption_cache.clear() + # Create the global APP_CONFIG instance APP_CONFIG = Configuration() \ No newline at end of file diff --git a/modules/shared/dbMultiTenantOptimizations.py b/modules/shared/dbMultiTenantOptimizations.py index c178c376..9b5a15b4 100644 --- a/modules/shared/dbMultiTenantOptimizations.py +++ b/modules/shared/dbMultiTenantOptimizations.py @@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None: logger.debug(f"_ensureUamTablesMatchModels: {e}") -def _getConnection(dbConnector): - """Get a connection from the DatabaseConnector. - - Ensures the connection is alive and returns it. - Commits any pending transaction first to avoid blocking. +from contextlib import contextmanager + + +@contextmanager +def _borrowDbConn(dbConnector): + """Borrow a pooled connection from the DatabaseConnector. + + Index/trigger/FK creation traditionally ran with `conn.autocommit = True` + so each CREATE statement is its own transaction (DDL on a managed + connection blocks waiting for COMMIT). This helper preserves that + behaviour on top of the pool: borrow a connection, flip it to autocommit, + yield it, and restore the previous state before returning it to the pool. """ - dbConnector._ensure_connection() - conn = dbConnector.connection - # Commit any pending transaction to avoid blocking - try: - conn.commit() - except Exception: - pass # Ignore if nothing to commit - return conn + with dbConnector.borrowConn() as conn: + try: + previousAutocommit = conn.autocommit + except Exception: + previousAutocommit = False + try: + conn.autocommit = True + except Exception as e: + logger.debug(f"Could not set autocommit on borrowed connection: {e}") + try: + yield conn + finally: + try: + conn.autocommit = previousAutocommit + except Exception: + pass # ============================================================================= @@ -174,73 +189,42 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non } try: - # Get a connection from the connector - conn = _getConnection(dbConnector) - - # Save and set autocommit state - try: - originalAutocommit = conn.autocommit - except Exception: - originalAutocommit = False - - try: - conn.autocommit = True - except Exception as autoErr: - logger.debug(f"Could not set autocommit: {autoErr}") - try: _ensureUamTablesMatchModels(dbConnector) except Exception as preIdxErr: logger.debug(f"Pre-index table ensure: {preIdxErr}") - - try: + + with _borrowDbConn(dbConnector) as conn: with conn.cursor() as cursor: - # Apply indexes results["indexesCreated"] = _applyIndexes(cursor, tables) - - # Apply foreign keys results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables) - - # Apply immutable triggers results["triggersCreated"] = _applyImmutableTriggers(cursor, tables) - - logger.info( - f"Multi-tenant optimizations applied: " - f"{results['indexesCreated']} indexes, " - f"{results['triggersCreated']} triggers, " - f"{results['foreignKeysCreated']} foreign keys" - ) - finally: - # Restore original autocommit state - try: - conn.autocommit = originalAutocommit - except Exception: - pass - + + logger.info( + f"Multi-tenant optimizations applied: " + f"{results['indexesCreated']} indexes, " + f"{results['triggersCreated']} triggers, " + f"{results['foreignKeysCreated']} foreign keys" + ) + except Exception as e: logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}") results["errors"].append(str(e)) - + return results def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int: """Apply only indexes (lighter operation, safe for frequent calls).""" try: - conn = _getConnection(dbConnector) - originalAutocommit = conn.autocommit - conn.autocommit = True - try: _ensureUamTablesMatchModels(dbConnector) except Exception as preIdxErr: logger.debug(f"Pre-index table ensure: {preIdxErr}") - - try: + + with _borrowDbConn(dbConnector) as conn: with conn.cursor() as cursor: return _applyIndexes(cursor, tables) - finally: - conn.autocommit = originalAutocommit except Exception as e: logger.error(f"Error applying indexes: {e}") return 0 @@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict: } try: - conn = _getConnection(dbConnector) - with conn.cursor() as cursor: + with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor: # Check regular indexes for tableName, indexName, _ in _INDEXES: if _tableExists(cursor, tableName): diff --git a/modules/shared/gdprDeletion.py b/modules/shared/gdprDeletion.py index 99e09313..45a9ea43 100644 --- a/modules/shared/gdprDeletion.py +++ b/modules/shared/gdprDeletion.py @@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]: ORDER BY ordinal_position """ - cursor = dbConnector.connection.cursor() - cursor.execute(query, (tableName,)) - columns = [row[0] for row in cursor.fetchall()] - cursor.close() - + with dbConnector.borrowCursor() as cursor: + cursor.execute(query, (tableName,)) + columns = [row[0] for row in cursor.fetchall()] return columns except Exception as e: logger.error(f"Error getting columns for table {tableName}: {e}") @@ -92,29 +90,26 @@ def _getAllTables(dbConnector) -> List[str]: ORDER BY table_name """ - cursor = dbConnector.connection.cursor() - cursor.execute(query) - allTables = [row[0] for row in cursor.fetchall()] - - # Get foreign key relationships to determine dependency order - fkQuery = """ - SELECT - tc.table_name, - ccu.table_name AS foreign_table_name - FROM information_schema.table_constraints AS tc - JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema - JOIN information_schema.constraint_column_usage AS ccu - ON ccu.constraint_name = tc.constraint_name - AND ccu.table_schema = tc.table_schema - WHERE tc.constraint_type = 'FOREIGN KEY' - AND tc.table_schema = 'public' - """ - - cursor.execute(fkQuery) - foreignKeys = cursor.fetchall() - cursor.close() + with dbConnector.borrowCursor() as cursor: + cursor.execute(query) + allTables = [row[0] for row in cursor.fetchall()] + + fkQuery = """ + SELECT + tc.table_name, + ccu.table_name AS foreign_table_name + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = 'public' + """ + cursor.execute(fkQuery) + foreignKeys = cursor.fetchall() # Build dependency graph (child -> parent mapping) dependencies = {} @@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]: # Fallback: return simple list without ordering try: query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'" - cursor = dbConnector.connection.cursor() - cursor.execute(query) - tables = [row[0] for row in cursor.fetchall()] - cursor.close() + with dbConnector.borrowCursor() as cursor: + cursor.execute(query) + tables = [row[0] for row in cursor.fetchall()] return [t for t in tables if t not in PROTECTED_TABLES] except Exception: return [] @@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]: AND i.indisprimary """ - cursor = dbConnector.connection.cursor() - cursor.execute(query, (tableName,)) - pkColumns = [row[0] for row in cursor.fetchall()] - cursor.close() - + with dbConnector.borrowCursor() as cursor: + cursor.execute(query, (tableName,)) + pkColumns = [row[0] for row in cursor.fetchall()] return pkColumns except Exception as e: logger.debug(f"Could not get primary key for {tableName}: {e}") @@ -229,21 +221,15 @@ def _findUserReferencesInTable( return {} references = {} - cursor = dbConnector.connection.cursor() - - for userColumn in userColumns: - # Build SELECT for primary key columns - pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns]) - query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s' - - cursor.execute(query, (userId,)) - recordKeys = cursor.fetchall() - - if recordKeys: - references[userColumn] = recordKeys - logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}") - - cursor.close() + with dbConnector.borrowCursor() as cursor: + for userColumn in userColumns: + pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns]) + query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s' + cursor.execute(query, (userId,)) + recordKeys = cursor.fetchall() + if recordKeys: + references[userColumn] = recordKeys + logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}") return references except Exception as e: @@ -277,42 +263,35 @@ def _anonymizeRecords( return 0 try: - cursor = dbConnector.connection.cursor() + # Resolve column metadata once outside the borrow block (it borrows its + # own connection internally). + columns = _getTableColumns(dbConnector, tableName) + hasModifiedAt = "sysModifiedAt" in columns + count = 0 - - for recordKey in recordKeys: - # Build WHERE clause for primary key - whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) - - # Check if table has sysModifiedAt column - columns = _getTableColumns(dbConnector, tableName) - hasModifiedAt = "sysModifiedAt" in columns - - if hasModifiedAt: - query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}' - params = [anonymousValue, getUtcTimestamp()] - else: - query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}' - params = [anonymousValue] - - # Add primary key values to params - if isinstance(recordKey, tuple): - params.extend(recordKey) - else: - params.append(recordKey) - - cursor.execute(query, params) - count += cursor.rowcount - - dbConnector.connection.commit() - cursor.close() - + with dbConnector.borrowCursor() as cursor: + for recordKey in recordKeys: + whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) + if hasModifiedAt: + query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}' + params = [anonymousValue, getUtcTimestamp()] + else: + query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}' + params = [anonymousValue] + + if isinstance(recordKey, tuple): + params.extend(recordKey) + else: + params.append(recordKey) + + cursor.execute(query, params) + count += cursor.rowcount + logger.info(f"Anonymized {count} records in {tableName}.{columnName}") return count - + except Exception as e: logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}") - dbConnector.connection.rollback() return 0 @@ -338,32 +317,23 @@ def _deleteRecords( return 0 try: - cursor = dbConnector.connection.cursor() count = 0 - - for recordKey in recordKeys: - # Build WHERE clause for primary key - whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) - query = f'DELETE FROM "{tableName}" WHERE {whereClause}' - - # Prepare params - if isinstance(recordKey, tuple): - params = list(recordKey) - else: - params = [recordKey] - - cursor.execute(query, params) - count += cursor.rowcount - - dbConnector.connection.commit() - cursor.close() - + with dbConnector.borrowCursor() as cursor: + for recordKey in recordKeys: + whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) + query = f'DELETE FROM "{tableName}" WHERE {whereClause}' + if isinstance(recordKey, tuple): + params = list(recordKey) + else: + params = [recordKey] + cursor.execute(query, params) + count += cursor.rowcount + logger.info(f"Deleted {count} records from {tableName}") return count - + except Exception as e: logger.error(f"Error deleting records from {tableName}: {e}") - dbConnector.connection.rollback() return 0 diff --git a/scripts/stage0_filefolder_schema_check.py b/scripts/stage0_filefolder_schema_check.py index 861d8671..d172e19c 100644 --- a/scripts/stage0_filefolder_schema_check.py +++ b/scripts/stage0_filefolder_schema_check.py @@ -25,7 +25,7 @@ if not c or not c.connection: print("STAGE0: DB_CONNECTION=none (check config.ini / .env)") raise SystemExit(2) -cur = c.connection.cursor() +cur = c.borrowCursor() def _scalar(cur): diff --git a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py index 4f98ef4a..57094760 100644 --- a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py +++ b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py @@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like: These tests pin the new contract: empty result sets still return ``[]`` / ``None`` (normal), but any exception inside the query path propagates as -``DatabaseQueryError`` with the table name attached. The transaction is -rolled back so the connection is usable for subsequent queries. +``DatabaseQueryError`` with the table name attached. + +Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`) +the connector borrows a connection from `_PoolRegistry` on every call via the +`borrowConn()` context manager. The tests mock that context manager so the +fast-fail contract is exercised without requiring a live Postgres server. """ from __future__ import annotations +from contextlib import contextmanager from unittest.mock import MagicMock import pytest @@ -25,7 +30,6 @@ import psycopg2.errors from modules.connectors.connectorDbPostgre import ( DatabaseConnector, DatabaseQueryError, - _rollbackQuietly, ) @@ -39,26 +43,44 @@ class DummyTable: def _makeConnector(cursorBehavior): - """Build a ``DatabaseConnector`` skeleton with mocked connection/cursor. + """Build a ``DatabaseConnector`` skeleton with a mocked pool borrow. ``cursorBehavior`` is a callable invoked with the cursor mock so the test can configure ``execute``/``fetchall``/``fetchone`` per scenario. + + Returns ``(connector, conn, cursor)``: + * ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can + assert that the borrow lifecycle did the right thing. + * ``cursor`` is the per-test cursor mock. """ connector = DatabaseConnector.__new__(DatabaseConnector) + cursor = MagicMock() + cursorBehavior(cursor) + cursorContext = MagicMock() cursorContext.__enter__ = MagicMock(return_value=cursor) cursorContext.__exit__ = MagicMock(return_value=False) - connection = MagicMock() - connection.cursor.return_value = cursorContext - connector.connection = connection + conn = MagicMock() + conn.cursor.return_value = cursorContext + + @contextmanager + def fakeBorrow(): + try: + yield conn + except Exception: + conn.rollback() + raise + else: + conn.commit() + + connector.borrowConn = fakeBorrow connector._ensureTableExists = MagicMock(return_value=True) connector._systemTableName = "_system" - cursorBehavior(cursor) - return connector, connection, cursor + return connector, conn, cursor class TestGetRecordsetFailLoud: @@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud: def behavior(cursor): cursor.execute.return_value = None cursor.fetchall.return_value = [] - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) result = connector.getRecordset(DummyTable) assert result == [] - connection.rollback.assert_not_called() + conn.rollback.assert_not_called() + conn.commit.assert_called_once() def test_dictAdaptErrorRaisesDatabaseQueryError(self): """Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise.""" @@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud: cursor.execute.side_effect = psycopg2.ProgrammingError( "can't adapt type 'dict'" ) - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) with pytest.raises(DatabaseQueryError) as excinfo: connector.getRecordset( @@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud: assert excinfo.value.table == "DummyTable" assert "can't adapt type 'dict'" in str(excinfo.value) assert isinstance(excinfo.value.original, psycopg2.ProgrammingError) - connection.rollback.assert_called_once() + conn.rollback.assert_called_once() def test_missingColumnRaisesDatabaseQueryError(self): def behavior(cursor): cursor.execute.side_effect = psycopg2.errors.UndefinedColumn( 'column "wat" does not exist' ) - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) with pytest.raises(DatabaseQueryError) as excinfo: connector.getRecordset(DummyTable, recordFilter={"wat": "x"}) assert "wat" in str(excinfo.value) - connection.rollback.assert_called_once() + conn.rollback.assert_called_once() def test_operationalErrorRaisesDatabaseQueryError(self): """Connection lost mid-query is also a real failure that must propagate.""" def behavior(cursor): cursor.execute.side_effect = psycopg2.OperationalError("connection lost") - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) with pytest.raises(DatabaseQueryError): connector.getRecordset(DummyTable) - connection.rollback.assert_called_once() + conn.rollback.assert_called_once() class TestGetRecordFailLoud: @@ -122,37 +145,22 @@ class TestGetRecordFailLoud: def behavior(cursor): cursor.execute.return_value = None cursor.fetchone.return_value = None - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) result = connector.getRecord(DummyTable, "missing-id") assert result is None - connection.rollback.assert_not_called() + conn.rollback.assert_not_called() + conn.commit.assert_called_once() def test_queryErrorRaisesDatabaseQueryError(self): def behavior(cursor): cursor.execute.side_effect = psycopg2.errors.UndefinedTable( 'relation "DummyTable" does not exist' ) - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) with pytest.raises(DatabaseQueryError) as excinfo: connector.getRecord(DummyTable, "any-id") assert excinfo.value.table == "DummyTable" - connection.rollback.assert_called_once() - - -class TestRollbackQuietly: - def test_rollsBackOnLiveConnection(self): - connection = MagicMock() - _rollbackQuietly(connection) - connection.rollback.assert_called_once() - - def test_swallowsRollbackError(self): - """Rollback failure must not mask the original query error.""" - connection = MagicMock() - connection.rollback.side_effect = RuntimeError("rollback failed") - _rollbackQuietly(connection) - - def test_noopOnNoneConnection(self): - _rollbackQuietly(None) + conn.rollback.assert_called_once() diff --git a/tests/unit/connectors/test_connectorDbPostgre_pool.py b/tests/unit/connectors/test_connectorDbPostgre_pool.py new file mode 100644 index 00000000..9c389add --- /dev/null +++ b/tests/unit/connectors/test_connectorDbPostgre_pool.py @@ -0,0 +1,304 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Concurrency tests for the PostgreSQL connection pool. + +These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md` +refactor delivered: + +* T1 — 50 threads × 100 calls in parallel produce 0 `OperationalError`s and + every call completes within reasonable time (p99 < 2 s). +* T2 — Two threads `_loadRecord` + `_saveRecord` against the same connector + do not corrupt each other's cursors. +* T3 — `statement_timeout` aborts a runaway `pg_sleep(60)` after ~30 s and + releases the connection back into the pool cleanly. + +The tests need a real PostgreSQL server because the bug they guard against +only materialises with real psycopg2 sockets — a mocked connection never +hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads +`.env`) and are auto-skipped when the connection fails (no local Postgres, +wrong creds, etc.) so `pytest` keeps working in CI-only environments. + +To run them locally: + + pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v + +They use a throwaway database name (`poweron_pool_test_`) and drop it +in fixture teardown so they leave nothing behind. +""" +from __future__ import annotations + +import time +import uuid +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed + +import psycopg2 +import psycopg2.errors +import pytest +from pydantic import Field + +from modules.connectors.connectorDbPostgre import ( + DatabaseConnector, + _PoolRegistry, + closeAllPools, +) +from modules.datamodels.datamodelBase import PowerOnModel +from modules.shared.configuration import APP_CONFIG + + +def _dbConfig(): + """Read DB connection params from APP_CONFIG (`.env`). + + Returns ``None`` when host/user/password are not all present so the + test module can skip cleanly instead of blowing up at import time. + """ + host = APP_CONFIG.get("DB_HOST") + user = APP_CONFIG.get("DB_USER") + password = APP_CONFIG.get("DB_PASSWORD_SECRET") + port = APP_CONFIG.get("DB_PORT", 5432) + if not host or not user or password is None: + return None + return {"host": host, "user": user, "password": password, "port": int(port)} + + +def _canReachPostgres(cfg) -> bool: + """Try a quick connect to the admin DB so we can skip on connection failures.""" + try: + conn = psycopg2.connect( + host=cfg["host"], port=cfg["port"], database="postgres", + user=cfg["user"], password=cfg["password"], connect_timeout=2, + ) + conn.close() + return True + except Exception: # noqa: BLE001 + return False + + +_DB_CFG = _dbConfig() +pytestmark = pytest.mark.skipif( + _DB_CFG is None or not _canReachPostgres(_DB_CFG), + reason="No reachable PostgreSQL — skipping live-Postgres pool tests", +) + + +class PoolTestRow(PowerOnModel): + """Tiny model used to exercise the pool — one ID + one payload field.""" + payload: str = Field(default="", description="Test payload") + + +@pytest.fixture +def liveConnector(): + """Spin up a throwaway database, yield a `DatabaseConnector` against it, + drop the database afterwards. + + The pool registry is wiped before and after each test so state from one + test cannot mask a bug in another. + """ + cfg = _DB_CFG + host = cfg["host"] + user = cfg["user"] + password = cfg["password"] + port = cfg["port"] + dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}" + + # Pre-clean: drop any orphan test DB with the same name (shouldn't happen + # because we use a unique uuid, but be defensive). + adminConn = psycopg2.connect( + host=host, port=port, database="postgres", user=user, password=password + ) + adminConn.autocommit = True + try: + with adminConn.cursor() as cur: + cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"') + finally: + adminConn.close() + + closeAllPools() + + connector = DatabaseConnector( + dbHost=host, + dbDatabase=dbName, + dbUser=user, + dbPassword=password, + dbPort=port, + ) + # Seed exactly one row so every concurrent read has a stable target. + connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"}) + + yield connector + + # Teardown: tear pools down, then drop the DB. + closeAllPools() + adminConn = psycopg2.connect( + host=host, port=port, database="postgres", user=user, password=password + ) + adminConn.autocommit = True + try: + with adminConn.cursor() as cur: + cur.execute( + 'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s', + (dbName,), + ) + cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"') + finally: + adminConn.close() + + +class TestPoolConcurrency: + def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int): + """Run N worker threads, each issuing M reads. Return (errors, latencies).""" + errors: list = [] + latencies: list = [] + lock = threading.Lock() + + def worker(): + for _ in range(callsPerThread): + t0 = time.perf_counter() + try: + rows = liveConnector.getRecordset(PoolTestRow) + assert any(r["id"] == "seed" for r in rows) + except Exception as e: # noqa: BLE001 — we want every failure mode + with lock: + errors.append(e) + finally: + with lock: + latencies.append(time.perf_counter() - t0) + + with ThreadPoolExecutor(max_workers=threadCount) as ex: + futures = [ex.submit(worker) for _ in range(threadCount)] + for f in as_completed(futures): + f.result() + latencies.sort() + return errors, latencies + + def test_50_threads_x_20_reads_no_errors(self, liveConnector): + """T1a — STRESS: 50 threads × 20 reads each → 0 errors. + + Pre-pool, this scenario produced either + `OperationalError: another command is already in progress` or a + deadlock in `recv()` because the threadpool shared one psycopg2 + socket. With the pool plus `borrowConn`'s bounded wait, every + thread eventually gets a connection and completes — even with 30 + threads queued waiting at any moment (pool max=20). + """ + errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20) + assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}" + + def test_20_threads_x_50_reads_latency_budget(self, liveConnector): + """T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s. + + 20 threads matches the pool's `max=20` so there is no queueing — + every borrow returns immediately. This pins a sanity-level per-call + latency budget; pre-pool it was unbounded (recv() never returned). + + The 5 s ceiling is generous on purpose: `getRecordset` calls + `_ensureTableExists` which runs two `information_schema` queries + for column-additive migration, and under 20-way concurrency on a + single Postgres instance that produces a long tail. The hard + assertion is `not errors` — the latency check just guarantees + nothing hangs indefinitely. + """ + errors, latencies = self._runWorkers( + liveConnector, threadCount=20, callsPerThread=50 + ) + assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}" + p99 = latencies[int(len(latencies) * 0.99)] + assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget" + + def test_interleaved_load_and_save_no_collision(self, liveConnector): + """T2: parallel reads + writes on the same connector → no cursor mix-up. + + Pre-pool the reader could observe a row in mid-write or vice versa + because both shared the same cursor. With one connection per borrow, + the database's own row-locking is the only contention, and we just + need to assert no exceptions. + """ + stopFlag = threading.Event() + errors: list = [] + lock = threading.Lock() + + def reader(): + while not stopFlag.is_set(): + try: + liveConnector.getRecord(PoolTestRow, "seed") + except Exception as e: # noqa: BLE001 + with lock: + errors.append(("read", e)) + + def writer(): + i = 0 + while not stopFlag.is_set(): + try: + liveConnector.recordModify( + PoolTestRow, + "seed", + {"id": "seed", "payload": f"v{i}"}, + ) + i += 1 + except Exception as e: # noqa: BLE001 + with lock: + errors.append(("write", e)) + + threads = [ + threading.Thread(target=reader, daemon=True), + threading.Thread(target=reader, daemon=True), + threading.Thread(target=writer, daemon=True), + threading.Thread(target=writer, daemon=True), + ] + for t in threads: + t.start() + time.sleep(2.0) + stopFlag.set() + for t in threads: + t.join(timeout=3.0) + + assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}" + + def test_statement_timeout_releases_connection(self, liveConnector): + """T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact. + + The bug we are guarding against: a runaway query with no timeout + hung `recv()` forever, the psycopg2 connection was poisoned, and the + whole backend became unresponsive once that connection was reused. + With `statement_timeout=30000` configured at pool construction the + query is cancelled by the server, the borrow context manager rolls + back, and the connection returns to the pool — proven by the fact + that a follow-up call still succeeds quickly. + """ + # Use a short timeout to keep the test fast — override the pool's + # session statement_timeout for one borrow via SET LOCAL. + with liveConnector.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute("SET LOCAL statement_timeout = 500") + with pytest.raises(psycopg2.errors.QueryCanceled): + cursor.execute("SELECT pg_sleep(5)") + + # Follow-up call must succeed quickly: connection is back in the pool. + t0 = time.perf_counter() + rows = liveConnector.getRecordset(PoolTestRow) + elapsed = time.perf_counter() - t0 + assert any(r["id"] == "seed" for r in rows) + assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged" + + +class TestPoolRegistry: + def test_one_pool_per_database_identity(self, liveConnector): + """Two connectors against the same (host, db, port) share one pool.""" + cfg = _DB_CFG + pool1 = _PoolRegistry.getPool( + dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase, + dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"], + ) + pool2 = _PoolRegistry.getPool( + dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase, + dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"], + ) + assert pool1 is pool2 + + def test_close_all_clears_registry(self, liveConnector): + """`closeAllPools()` empties the registry so the next call rebuilds.""" + # Touch the pool first. + liveConnector.getRecordset(PoolTestRow) + assert _PoolRegistry._pools, "pool should exist after a real call" + closeAllPools() + assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()" diff --git a/tests/unit/interfaces/test_folderRbac.py b/tests/unit/interfaces/test_folderRbac.py index 049f392d..f4b984aa 100644 --- a/tests/unit/interfaces/test_folderRbac.py +++ b/tests/unit/interfaces/test_folderRbac.py @@ -68,6 +68,16 @@ class _FakeDb: def _ensureTableExists(self, modelClass): return True + def borrowCursor(self): + """Mimic `DatabaseConnector.borrowCursor()` context manager.""" + from contextlib import contextmanager + from unittest.mock import MagicMock + + @contextmanager + def _cm(): + yield MagicMock() + return _cm() + def seed(self, modelClass, record: Dict[str, Any]): tableName = modelClass.__name__ self._tables.setdefault(tableName, {}) diff --git a/tests/unit/routes/test_folder_crud.py b/tests/unit/routes/test_folder_crud.py index 86eaf480..66bad903 100644 --- a/tests/unit/routes/test_folder_crud.py +++ b/tests/unit/routes/test_folder_crud.py @@ -69,6 +69,16 @@ class _FakeDb: def _ensureTableExists(self, modelClass): return True + def borrowCursor(self): + """Mimic `DatabaseConnector.borrowCursor()` context manager for the cascade test.""" + from contextlib import contextmanager + from unittest.mock import MagicMock + + @contextmanager + def _cm(): + yield MagicMock() + return _cm() + def seed(self, modelClass, record: Dict[str, Any]): tableName = modelClass.__name__ self._tables.setdefault(tableName, {}) From 4064ac0266b67e77625a0726fec1f918f7ac1c88 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Mon, 18 May 2026 07:56:53 +0200 Subject: [PATCH 2/6] fixed toggle icons udb --- app.py | 3 + modules/datamodels/datamodelBackgroundJob.py | 11 + modules/datamodels/datamodelDataSource.py | 41 ++- .../datamodels/datamodelFeatureDataSource.py | 29 +- .../trustee/accounting/accountingDataSync.py | 9 +- modules/features/trustee/mainTrustee.py | 21 ++ .../features/trustee/routeFeatureTrustee.py | 24 +- .../workspace/routeFeatureWorkspace.py | 1 + modules/routes/routeDataSources.py | 267 ++++++++++++-- modules/routes/routeJobs.py | 18 +- modules/routes/routeRagInventory.py | 73 +++- .../mainBackgroundJobService.py | 40 +- .../services/serviceChat/mainServiceChat.py | 10 +- .../serviceKnowledge/_costEstimate.py | 86 +++++ .../serviceKnowledge/_inheritFlags.py | 342 ++++++++++++++++++ .../serviceKnowledge/_progressMessages.py | 23 ++ .../services/serviceKnowledge/_ragLimits.py | 107 ++++++ .../subConnectorIngestConsumer.py | 43 ++- .../subConnectorSyncClickup.py | 27 +- .../subConnectorSyncGdrive.py | 31 +- .../serviceKnowledge/subConnectorSyncGmail.py | 6 +- .../subConnectorSyncKdrive.py | 31 +- .../subConnectorSyncOutlook.py | 6 +- .../subConnectorSyncSharepoint.py | 36 +- .../serviceKnowledge/subPolicyResolver.py | 70 +--- modules/shared/i18nRegistry.py | 42 +++ scripts/debug_rag_job_result.py | 70 ++++ ..._db_migrate_backgroundjob_progress_data.py | 97 +++++ .../script_db_migrate_datasource_inherit.py | 110 ++++++ .../script_db_migrate_datasource_settings.py | 102 ++++++ tests/unit/services/test_costEstimate.py | 55 +++ tests/unit/services/test_inheritFlags.py | 330 +++++++++++++++++ .../test_knowledge_ingest_consumer.py | 39 +- tests/unit/services/test_ragLimits.py | 79 ++++ 34 files changed, 2107 insertions(+), 172 deletions(-) create mode 100644 modules/serviceCenter/services/serviceKnowledge/_costEstimate.py create mode 100644 modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py create mode 100644 modules/serviceCenter/services/serviceKnowledge/_progressMessages.py create mode 100644 modules/serviceCenter/services/serviceKnowledge/_ragLimits.py create mode 100644 scripts/debug_rag_job_result.py create mode 100644 scripts/script_db_migrate_backgroundjob_progress_data.py create mode 100644 scripts/script_db_migrate_datasource_inherit.py create mode 100644 scripts/script_db_migrate_datasource_settings.py create mode 100644 tests/unit/services/test_costEstimate.py create mode 100644 tests/unit/services/test_inheritFlags.py create mode 100644 tests/unit/services/test_ragLimits.py diff --git a/app.py b/app.py index d94c7dd5..93cc8b79 100644 --- a/app.py +++ b/app.py @@ -418,6 +418,9 @@ async def lifespan(app: FastAPI): registerKnowledgeIngestionConsumer, ) registerKnowledgeIngestionConsumer() + # Side-effect import: registers all walker progress message keys + # in the i18n registry so `syncRegistryToDb` picks them up. + from modules.serviceCenter.services.serviceKnowledge import _progressMessages # noqa: F401 except Exception as e: logger.warning(f"KnowledgeIngestionConsumer registration failed (non-critical): {e}") diff --git a/modules/datamodels/datamodelBackgroundJob.py b/modules/datamodels/datamodelBackgroundJob.py index fa99ea34..809fb994 100644 --- a/modules/datamodels/datamodelBackgroundJob.py +++ b/modules/datamodels/datamodelBackgroundJob.py @@ -96,6 +96,17 @@ class BackgroundJob(PowerOnModel): description="Human-readable current step (e.g. 'Importing journal entries...')", json_schema_extra={"label": "Fortschritts-Nachricht"}, ) + progressMessageData: Optional[Dict[str, Any]] = Field( + None, + description=( + "Structured i18n payload for `progressMessage`. Shape: " + "{'key': '', 'params': {...}}. " + "Frontend renders via `t(key, params)`; older clients fall back " + "to `progressMessage`. Single source of truth — keep `progressMessage` " + "as the rendered fallback in the producing language." + ), + json_schema_extra={"label": "Fortschritts-Nachricht (i18n)"}, + ) payload: Dict[str, Any] = Field( default_factory=dict, diff --git a/modules/datamodels/datamodelDataSource.py b/modules/datamodels/datamodelDataSource.py index fe3f0442..de32bdf3 100644 --- a/modules/datamodels/datamodelDataSource.py +++ b/modules/datamodels/datamodelDataSource.py @@ -62,9 +62,14 @@ class DataSource(PowerOnModel): description="Owner user ID", json_schema_extra={"label": "Benutzer-ID", "fk_target": {"db": "poweron_app", "table": "UserInDB", "labelField": "username"}}, ) - ragIndexEnabled: bool = Field( - default=False, - description="When true this tree element is indexed into the RAG knowledge store", + ragIndexEnabled: Optional[bool] = Field( + default=None, + description=( + "Three-state RAG indexing flag with cascade-inherit semantics. " + "None = inherit from nearest ancestor DataSource (path-traversal); " + "True/False = explicit override that propagates to descendants. " + "Walker computes effective value via getEffectiveFlag()." + ), json_schema_extra={"label": "Im RAG indexieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, ) lastIndexed: Optional[float] = Field( @@ -72,9 +77,13 @@ class DataSource(PowerOnModel): description="Timestamp of last successful RAG indexing run", json_schema_extra={"label": "Letzte Indexierung", "frontend_type": "timestamp"}, ) - scope: str = Field( - default="personal", - description="Data visibility scope: personal, featureInstance, mandate, global", + scope: Optional[str] = Field( + default=None, + description=( + "Data visibility scope with inherit semantics. " + "None = inherit; values: personal, featureInstance, mandate, global. " + "Cascade-reset on parent toggle." + ), json_schema_extra={"label": "Sichtbarkeit", "frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [ {"value": "personal", "label": "Persönlich"}, {"value": "featureInstance", "label": "Feature-Instanz"}, @@ -82,11 +91,25 @@ class DataSource(PowerOnModel): {"value": "global", "label": "Global"}, ]}, ) - neutralize: bool = Field( - default=False, - description="Whether this data source should be neutralized before AI processing", + neutralize: Optional[bool] = Field( + default=None, + description=( + "Three-state neutralization flag with cascade-inherit semantics. " + "None = inherit from nearest ancestor DataSource (path-traversal); " + "True/False = explicit override that propagates to descendants." + ), json_schema_extra={"label": "Neutralisieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, ) + settings: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "DataSource-scoped settings (JSON). Currently used keys: " + "ragLimits.{maxBytes,maxFileSize,maxItems,maxDepth}. " + "Walker reads these directly; missing keys fall back to RAG_LIMITS_DEFAULT " + "and are lazily persisted on next bootstrap." + ), + json_schema_extra={"label": "Einstellungen", "frontend_type": "json", "frontend_readonly": True, "frontend_required": False}, + ) class ExternalEntry(BaseModel): diff --git a/modules/datamodels/datamodelFeatureDataSource.py b/modules/datamodels/datamodelFeatureDataSource.py index dd2c4035..f07a8bda 100644 --- a/modules/datamodels/datamodelFeatureDataSource.py +++ b/modules/datamodels/datamodelFeatureDataSource.py @@ -6,7 +6,7 @@ A FeatureDataSource links a FeatureInstance table (DATA_OBJECT) to a workspace so the agent can query structured feature data (e.g. TrusteePosition rows). """ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field from modules.datamodels.datamodelBase import PowerOnModel from modules.shared.i18nRegistry import i18nModel @@ -55,9 +55,12 @@ class FeatureDataSource(PowerOnModel): description="Workspace feature instance where this source is used", json_schema_extra={"label": "Workspace", "fk_target": {"db": "poweron_app", "table": "FeatureInstance", "labelField": "label"}}, ) - scope: str = Field( - default="personal", - description="Data visibility scope: personal, featureInstance, mandate, global", + scope: Optional[str] = Field( + default=None, + description=( + "Data visibility scope with inherit semantics. " + "None = inherit; values: personal, featureInstance, mandate, global." + ), json_schema_extra={"label": "Sichtbarkeit", "frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [ {"value": "personal", "label": "Persönlich"}, {"value": "featureInstance", "label": "Feature-Instanz"}, @@ -65,9 +68,12 @@ class FeatureDataSource(PowerOnModel): {"value": "global", "label": "Global"}, ]}, ) - neutralize: bool = Field( - default=False, - description="Whether this data source should be neutralized before AI processing", + neutralize: Optional[bool] = Field( + default=None, + description=( + "Three-state neutralization flag with cascade-inherit semantics. " + "None = inherit; True/False = explicit. Cascade-reset on parent toggle." + ), json_schema_extra={"label": "Neutralisieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, ) neutralizeFields: Optional[List[str]] = Field( @@ -80,3 +86,12 @@ class FeatureDataSource(PowerOnModel): description="Record-level filter applied when querying this table, e.g. {'sessionId': 'abc-123'}", json_schema_extra={"label": "Datensatzfilter"}, ) + settings: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "FeatureDataSource-scoped settings (JSON). Currently used keys: " + "ragLimits.{maxBytes,maxFileSize,maxItems,maxDepth}. " + "Mirror of DataSource.settings so the UDB settings modal can target both." + ), + json_schema_extra={"label": "Einstellungen", "frontend_type": "json", "frontend_readonly": True, "frontend_required": False}, + ) diff --git a/modules/features/trustee/accounting/accountingDataSync.py b/modules/features/trustee/accounting/accountingDataSync.py index 5827dd11..db50d657 100644 --- a/modules/features/trustee/accounting/accountingDataSync.py +++ b/modules/features/trustee/accounting/accountingDataSync.py @@ -205,11 +205,16 @@ class AccountingDataSync: boundary so the UI poll on ``GET /api/jobs/{jobId}`` shows real movement instead of jumping from 10 % to 100 %. Safe to omit. """ - def _progress(pct: int, msg: str) -> None: + def _progress(pct: int, msgKey: str, msgParams: Optional[Dict[str, Any]] = None) -> None: + """Forward to progressCb using the i18n contract. + + `msgKey` is the German plaintext-as-key; the frontend translates + it via `t(key, params)` when rendering. + """ if progressCb is None: return try: - progressCb(pct, msg) + progressCb(pct, messageKey=msgKey, messageParams=msgParams or {}) except Exception as ex: logger.warning(f"progressCb failed at {pct}%: {ex}") from modules.features.trustee.datamodelFeatureTrustee import ( diff --git a/modules/features/trustee/mainTrustee.py b/modules/features/trustee/mainTrustee.py index 8f725d2f..b3f7cdcf 100644 --- a/modules/features/trustee/mainTrustee.py +++ b/modules/features/trustee/mainTrustee.py @@ -12,6 +12,27 @@ from modules.shared.i18nRegistry import t logger = logging.getLogger(__name__) +# i18n: register BackgroundJob progress message keys used by routeFeatureTrustee / +# accountingDataSync. Walker call sites use `progressCb(..., messageKey="…")` +# without going through `t()`, so we must register each key here as a +# string-literal `t(...)` call -- per i18n convention `t()` MUST receive a +# literal so static scanners and the boot-time `syncRegistryToDb` can pick +# it up. Do NOT collapse these into a loop over a list of variables. +t("Sync wird vorbereitet ({total} Position(en))...") +t("Verbindungsaufbau fehlgeschlagen.") +t("Keine aktive Buchhaltungs-Konfiguration gefunden.") +t("Position {index}/{total} verarbeitet") +t("Sync abgeschlossen.") +t("Initialisiere Import...") +t("Verbinde mit Buchhaltungssystem...") +t("Import abgeschlossen.") +t("Lade Kontenplan...") +t("Lade Journaleintraege vom Buchhaltungssystem...") +t("Lade Kunden...") +t("Lade Lieferanten...") +t("Lade Kontensaldi vom Buchhaltungssystem...") +t("Speichere Kontensaldi...") + # Feature metadata FEATURE_CODE = "trustee" FEATURE_LABEL = t("Treuhand", context="UI") diff --git a/modules/features/trustee/routeFeatureTrustee.py b/modules/features/trustee/routeFeatureTrustee.py index 2c9c3328..a71b508f 100644 --- a/modules/features/trustee/routeFeatureTrustee.py +++ b/modules/features/trustee/routeFeatureTrustee.py @@ -1644,7 +1644,11 @@ async def _trusteeAccountingPushJobHandler(job: Dict[str, Any], progressCb) -> D results = [] total = len(positionIds) - progressCb(2, f"Sync wird vorbereitet ({total} Position(en))...") + progressCb( + 2, + messageKey="Sync wird vorbereitet ({total} Position(en))...", + messageParams={"total": total}, + ) # Resolve connector + plain config once to avoid decryption rate-limits # (mirrors the optimisation in pushBatchToAccounting). We push positions @@ -1655,12 +1659,12 @@ async def _trusteeAccountingPushJobHandler(job: Dict[str, Any], progressCb) -> D connector, plainConfig, configRecord = await bridge._resolveConnectorAndConfig(instanceId) except Exception as resolveErr: logger.exception("Accounting push: failed to resolve connector/config") - progressCb(100, "Verbindungsaufbau fehlgeschlagen.") + progressCb(100, messageKey="Verbindungsaufbau fehlgeschlagen.") raise resolveErr if not connector or not plainConfig: results = [SyncResult(success=False, errorMessage="No active accounting configuration found") for _ in positionIds] - progressCb(100, "Keine aktive Buchhaltungs-Konfiguration gefunden.") + progressCb(100, messageKey="Keine aktive Buchhaltungs-Konfiguration gefunden.") return { "total": len(results), "success": 0, @@ -1680,7 +1684,11 @@ async def _trusteeAccountingPushJobHandler(job: Dict[str, Any], progressCb) -> D results.append(result) # Reserve 5..95% for the push loop, keep the tail for summary. pct = 5 + int(90 * index / total) - progressCb(pct, f"Position {index}/{total} verarbeitet") + progressCb( + pct, + messageKey="Position {index}/{total} verarbeitet", + messageParams={"index": index, "total": total}, + ) skipped = [r for r in results if not r.success and r.errorMessage and "already synced" in r.errorMessage] failed = [r for r in results if not r.success and r not in skipped] @@ -1693,7 +1701,7 @@ async def _trusteeAccountingPushJobHandler(job: Dict[str, Any], progressCb) -> D "; ".join(r.errorMessage or "unknown" for r in failed[:3]), ) - progressCb(100, "Sync abgeschlossen.") + progressCb(100, messageKey="Sync abgeschlossen.") return { "total": len(results), "success": sum(1 for r in results if r.success), @@ -1823,10 +1831,10 @@ async def _trusteeAccountingSyncJobHandler(job: Dict[str, Any], progressCb) -> D payload = job.get("payload") or {} rootUser = getRootUser() - progressCb(5, "Initialisiere Import...") + progressCb(5, messageKey="Initialisiere Import...") interface = getInterface(rootUser, mandateId=mandateId, featureInstanceId=instanceId) sync = AccountingDataSync(interface) - progressCb(10, "Verbinde mit Buchhaltungssystem...") + progressCb(10, messageKey="Verbinde mit Buchhaltungssystem...") result = await sync.importData( featureInstanceId=instanceId, mandateId=mandateId, @@ -1834,7 +1842,7 @@ async def _trusteeAccountingSyncJobHandler(job: Dict[str, Any], progressCb) -> D dateTo=payload.get("dateTo"), progressCb=progressCb, ) - progressCb(100, "Import abgeschlossen.") + progressCb(100, messageKey="Import abgeschlossen.") return result diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py index 4487e5fe..2fa788e8 100644 --- a/modules/features/workspace/routeFeatureWorkspace.py +++ b/modules/features/workspace/routeFeatureWorkspace.py @@ -1324,6 +1324,7 @@ async def listWorkspaceConnections( "externalUsername": conn.get("externalUsername"), "externalEmail": conn.get("externalEmail"), "status": status, + "knowledgeIngestionEnabled": bool(conn.get("knowledgeIngestionEnabled")), }) return JSONResponse({"connections": items}) diff --git a/modules/routes/routeDataSources.py b/modules/routes/routeDataSources.py index ba398008..5dec19c8 100644 --- a/modules/routes/routeDataSources.py +++ b/modules/routes/routeDataSources.py @@ -9,11 +9,40 @@ from fastapi import APIRouter, HTTPException, Depends, Path, Request, Body from modules.auth import limiter, getRequestContext, RequestContext from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource +from modules.datamodels.datamodelUam import UserConnection from modules.shared.i18nRegistry import apiRouteContext routeApiMsg = apiRouteContext("routeDataSources") logger = logging.getLogger(__name__) + +def _ensureConnectionKnowledgeFlag(rootIf, connectionId: str) -> None: + """Forward-only sync: if a DataSource gets RAG-activated, ensure the parent + UserConnection.knowledgeIngestionEnabled is true. + + Intentionally NOT bidirectional: disabling the last DataSource does NOT + auto-clear knowledgeIngestionEnabled, because the consent flag may have + been set explicitly via the Connections page / wizard even before any + DataSource exists. Only the master switch (`/knowledge-consent`) may + clear it. + """ + if not connectionId: + return + try: + currentConn = rootIf.db.getRecord(UserConnection, connectionId) + if not currentConn: + return + if bool(currentConn.get("knowledgeIngestionEnabled")): + return + rootIf.db.recordModify(UserConnection, connectionId, {"knowledgeIngestionEnabled": True}) + logger.info( + "Auto-enabled knowledgeIngestionEnabled on UserConnection %s " + "(triggered by first active DataSource).", + connectionId, + ) + except Exception as e: + logger.warning("Could not auto-enable knowledgeIngestionEnabled for connection %s: %s", connectionId, e) + router = APIRouter( prefix="/api/datasources", tags=["Data Sources"], @@ -45,26 +74,43 @@ def _findSourceRecord(db, sourceId: str): def _updateDataSourceScope( request: Request, sourceId: str = Path(..., description="ID of the DataSource or FeatureDataSource"), - scope: str = Body(..., embed=True), + scope: Optional[str] = Body(None, embed=True), context: RequestContext = Depends(getRequestContext), ) -> Dict[str, Any]: - """Update the scope of a DataSource or FeatureDataSource. Global scope requires sysAdmin.""" - if scope not in _VALID_SCOPES: - raise HTTPException(status_code=400, detail=f"Invalid scope: {scope}. Must be one of {_VALID_SCOPES}") + """Update the scope of a DataSource. Cascade-resets explicit descendants. - if scope == "global" and not context.isSysAdmin: - raise HTTPException(status_code=403, detail=routeApiMsg("Only sysadmins can set global scope")) + `scope=None` resets this node to inherit (no cascade). Global scope + requires sysAdmin. + """ + if scope is not None: + if scope not in _VALID_SCOPES: + raise HTTPException(status_code=400, detail=f"Invalid scope: {scope}. Must be one of {_VALID_SCOPES}") + if scope == "global" and not context.isSysAdmin: + raise HTTPException(status_code=403, detail=routeApiMsg("Only sysadmins can set global scope")) try: from modules.interfaces.interfaceDbApp import getRootInterface + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + cascadeResetDescendants, + cascadeResetDescendantsFds, + ) rootIf = getRootInterface() rec, model = _findSourceRecord(rootIf.db, sourceId) if not rec: raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") rootIf.db.recordModify(model, sourceId, {"scope": scope}) - logger.info("Updated scope=%s for %s %s", scope, model.__name__, sourceId) - return {"sourceId": sourceId, "scope": scope, "updated": True} + cascaded = 0 + if scope is not None: + if model is DataSource: + cascaded = cascadeResetDescendants(rootIf, rec, "scope") + else: + cascaded = cascadeResetDescendantsFds(rootIf, rec, "scope") + logger.info( + "Updated scope=%s for %s %s (cascade-reset %d descendants)", + scope, model.__name__, sourceId, cascaded, + ) + return {"sourceId": sourceId, "scope": scope, "updated": True, "cascadedDescendants": cascaded} except HTTPException: raise except Exception as e: @@ -77,20 +123,36 @@ def _updateDataSourceScope( def _updateDataSourceNeutralize( request: Request, sourceId: str = Path(..., description="ID of the DataSource or FeatureDataSource"), - neutralize: bool = Body(..., embed=True), + neutralize: Optional[bool] = Body(None, embed=True), context: RequestContext = Depends(getRequestContext), ) -> Dict[str, Any]: - """Toggle the neutralization flag on a DataSource or FeatureDataSource.""" + """Set neutralize flag on a DataSource. Cascade-resets explicit descendants. + + `neutralize=None` resets this node to inherit (no cascade). + """ try: from modules.interfaces.interfaceDbApp import getRootInterface + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + cascadeResetDescendants, + cascadeResetDescendantsFds, + ) rootIf = getRootInterface() rec, model = _findSourceRecord(rootIf.db, sourceId) if not rec: raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize}) - logger.info("Updated neutralize=%s for %s %s", neutralize, model.__name__, sourceId) - return {"sourceId": sourceId, "neutralize": neutralize, "updated": True} + cascaded = 0 + if neutralize is not None: + if model is DataSource: + cascaded = cascadeResetDescendants(rootIf, rec, "neutralize") + else: + cascaded = cascadeResetDescendantsFds(rootIf, rec, "neutralize") + logger.info( + "Updated neutralize=%s for %s %s (cascade-reset %d descendants)", + neutralize, model.__name__, sourceId, cascaded, + ) + return {"sourceId": sourceId, "neutralize": neutralize, "updated": True, "cascadedDescendants": cascaded} except HTTPException: raise except Exception as e: @@ -132,13 +194,14 @@ def _updateNeutralizeFields( async def _updateDataSourceRagIndex( request: Request, sourceId: str = Path(..., description="ID of the DataSource"), - ragIndexEnabled: bool = Body(..., embed=True), + ragIndexEnabled: Optional[bool] = Body(None, embed=True), context: RequestContext = Depends(getRequestContext), ) -> Dict[str, Any]: - """Toggle RAG indexing for a DataSource. + """Set RAG indexing flag on a DataSource. Cascade-resets explicit descendants. - true: sets flag + enqueues mini-bootstrap for this DataSource only. - false: sets flag + synchronously purges all chunks from this DataSource. + `ragIndexEnabled=None` resets this node to inherit (no cascade, no purge, + no bootstrap — the node simply follows its ancestor chain afterwards). + `True` enqueues a mini-bootstrap. `False` synchronously purges chunks. Must be `async def` so `await startJob(...)` registers `_runJob` in the main event loop. Sync route → worker thread → temporary loop closes @@ -146,18 +209,26 @@ async def _updateDataSourceRagIndex( """ try: from modules.interfaces.interfaceDbApp import getRootInterface + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import cascadeResetDescendants rootIf = getRootInterface() rec = rootIf.db.getRecord(DataSource, sourceId) if not rec: raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") rootIf.db.recordModify(DataSource, sourceId, {"ragIndexEnabled": ragIndexEnabled}) - logger.info("Updated ragIndexEnabled=%s for DataSource %s", ragIndexEnabled, sourceId) + cascaded = 0 + if ragIndexEnabled is not None: + cascaded = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled") + logger.info( + "Updated ragIndexEnabled=%s for DataSource %s (cascade-reset %d descendants)", + ragIndexEnabled, sourceId, cascaded, + ) - if ragIndexEnabled: + connectionId = rec.get("connectionId") or rec.get("connection_id") or "" + if ragIndexEnabled is True: + _ensureConnectionKnowledgeFlag(rootIf, connectionId) from modules.serviceCenter.services.serviceBackgroundJobs import startJob - connectionId = rec.get("connectionId") or rec.get("connection_id") or "" conn = rootIf.getUserConnectionById(connectionId) if connectionId else None authority = "" if conn: @@ -168,7 +239,7 @@ async def _updateDataSourceRagIndex( {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]}, triggeredBy=str(context.user.id), ) - else: + elif ragIndexEnabled is False: from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId) logger.info("Purged %d index rows / %d chunks for DataSource %s", @@ -182,12 +253,164 @@ async def _updateDataSourceRagIndex( mandateId=context.mandateId, category=AuditCategory.PERMISSION.value, action="rag_index_toggled", - details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled}), + details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "cascadedDescendants": cascaded}), ) - return {"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "updated": True} + return {"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "updated": True, "cascadedDescendants": cascaded} except HTTPException: raise except Exception as e: logger.error("Error updating datasource ragIndexEnabled: %s", e) raise HTTPException(status_code=500, detail=str(e)) + + +_CLICKUP_SOURCE_TYPES = {"clickup", "clickupList", "clickupSpace", "clickupFolder"} +_ALLOWED_RAG_LIMIT_KEYS = { + "files": {"maxItems", "maxBytes", "maxFileSize", "maxDepth"}, + "clickup": {"maxTasks", "maxWorkspaces", "maxListsPerWorkspace"}, +} + + +def _kindForSource(rec: Dict[str, Any], model) -> str: + """Map a DataSource record to a RAG-limits kind ('files' or 'clickup'). + + FeatureDataSource (tables, not file walkers) reports as 'files' so the + same UI/limit shape works; the limits simply won't be consumed by any + walker today but are stored for forward-compat. + """ + if model is FeatureDataSource: + return "files" + sourceType = str(rec.get("sourceType") or "").strip() + return "clickup" if sourceType in _CLICKUP_SOURCE_TYPES else "files" + + +def _sanitizeRagLimits(kind: str, raw: Any) -> Dict[str, int]: + """Coerce an incoming ragLimits dict to {allowedKey: positive int}. + + Unknown keys are silently dropped; non-positive or non-numeric values + are rejected with 400. + """ + if not isinstance(raw, dict): + raise HTTPException(status_code=400, detail="ragLimits must be an object") + allowed = _ALLOWED_RAG_LIMIT_KEYS.get(kind, set()) + cleaned: Dict[str, int] = {} + for key, value in raw.items(): + if key not in allowed: + continue + try: + intValue = int(value) + except (TypeError, ValueError): + raise HTTPException(status_code=400, detail=f"ragLimits.{key} must be an integer") + if intValue <= 0: + raise HTTPException(status_code=400, detail=f"ragLimits.{key} must be > 0") + cleaned[key] = intValue + return cleaned + + +@router.patch("/{sourceId}/settings") +@limiter.limit("30/minute") +def _updateDataSourceSettings( + request: Request, + sourceId: str = Path(..., description="ID of the DataSource or FeatureDataSource"), + settings: Dict[str, Any] = Body(..., embed=True), + context: RequestContext = Depends(getRequestContext), +) -> Dict[str, Any]: + """Replace `settings` on a DataSource or FeatureDataSource (partial merge per top-level key). + + Currently supports `ragLimits` only. Unknown top-level keys in the body are + rejected to avoid silently storing garbage that no consumer reads. + + Owner-only for personal DataSources; mandate/feature scopes additionally + accept the mandate or workspace admins of that scope. + """ + if not isinstance(settings, dict): + raise HTTPException(status_code=400, detail="settings must be an object") + unknown = set(settings.keys()) - {"ragLimits"} + if unknown: + raise HTTPException(status_code=400, detail=f"Unknown settings keys: {sorted(unknown)}") + + try: + from modules.interfaces.interfaceDbApp import getRootInterface + rootIf = getRootInterface() + rec, model = _findSourceRecord(rootIf.db, sourceId) + if not rec: + raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") + + ownerId = str(rec.get("userId") or "") + currentUserId = str(context.user.id) + if ownerId and ownerId != currentUserId and not context.isSysAdmin: + scope = str(rec.get("scope") or "personal") + isMandateAdmin = getattr(context, "isMandateAdmin", False) + if scope == "personal" or not isMandateAdmin: + raise HTTPException(status_code=403, detail="Not allowed to modify this DataSource's settings") + + kind = _kindForSource(rec, model) + + currentSettings = rec.get("settings") or {} + if not isinstance(currentSettings, dict): + currentSettings = {} + newSettings = dict(currentSettings) + + if "ragLimits" in settings: + cleanedLimits = _sanitizeRagLimits(kind, settings["ragLimits"]) + mergedLimits = dict(currentSettings.get("ragLimits") or {}) + mergedLimits.update(cleanedLimits) + newSettings["ragLimits"] = mergedLimits + + rootIf.db.recordModify(model, sourceId, {"settings": newSettings}) + + import json + from modules.shared.auditLogger import audit_logger + from modules.datamodels.datamodelAudit import AuditCategory + audit_logger.logEvent( + userId=currentUserId, + mandateId=context.mandateId, + category=AuditCategory.PERMISSION.value, + action="datasource_settings_changed", + details=json.dumps({ + "sourceId": sourceId, + "model": model.__name__, + "oldSettings": currentSettings, + "newSettings": newSettings, + }), + ) + logger.info("Updated settings on %s %s by user %s", model.__name__, sourceId, currentUserId) + return {"sourceId": sourceId, "settings": newSettings, "updated": True} + except HTTPException: + raise + except Exception as e: + logger.error("Error updating datasource settings: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/{sourceId}/cost-estimate") +@limiter.limit("60/minute") +def _getDataSourceCostEstimate( + request: Request, + sourceId: str = Path(..., description="ID of the DataSource or FeatureDataSource"), + context: RequestContext = Depends(getRequestContext), +) -> Dict[str, Any]: + """Return an indicative full-sync cost estimate for the given DataSource. + + Uses the current effective ragLimits (DataSource.settings.ragLimits with + fallback to centralized defaults) as the basis. Returns the same + `{estimatedTokens, estimatedUsd, basis}` shape regardless of source kind. + """ + try: + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.serviceCenter.services.serviceKnowledge import _ragLimits, _costEstimate + rootIf = getRootInterface() + rec, model = _findSourceRecord(rootIf.db, sourceId) + if not rec: + raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") + + kind = _kindForSource(rec, model) + effective = _ragLimits.getRagLimits(rec, kind) + estimate = _costEstimate.estimateBootstrapCost(effective, kind=kind) + estimate["sourceId"] = sourceId + return estimate + except HTTPException: + raise + except Exception as e: + logger.error("Error computing cost estimate: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/modules/routes/routeJobs.py b/modules/routes/routeJobs.py index d2124a0b..9cd89d46 100644 --- a/modules/routes/routeJobs.py +++ b/modules/routes/routeJobs.py @@ -21,7 +21,7 @@ from modules.serviceCenter.services.serviceBackgroundJobs import ( getJobStatus, listJobs, ) -from modules.shared.i18nRegistry import apiRouteContext +from modules.shared.i18nRegistry import apiRouteContext, resolveJobMessage logger = logging.getLogger(__name__) routeApiMsg = apiRouteContext("routeJobs") @@ -34,8 +34,20 @@ router = APIRouter( def _serialiseJob(job: Dict[str, Any]) -> Dict[str, Any]: - """Strip system audit fields and ensure JSON-safe types.""" - return {k: v for k, v in job.items() if not k.startswith("sys")} + """Strip system audit fields, ensure JSON-safe types, translate progress. + + Walkers store progress as a structured payload (``progressMessageData = + {key, params}``). The frontend never calls ``t()`` on backend-supplied + keys (i18n convention #2), so we resolve the payload here using the + request-context language and overwrite ``progressMessage`` with the + fully rendered string. Older clients keep working because they read + the same field. + """ + out = {k: v for k, v in job.items() if not k.startswith("sys")} + translated = resolveJobMessage(out.get("progressMessageData")) + if translated: + out["progressMessage"] = translated + return out def _userHasMandateAccess(context: RequestContext, mandateId: Optional[str]) -> bool: diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py index 7c426d77..99d5c4df 100644 --- a/modules/routes/routeRagInventory.py +++ b/modules/routes/routeRagInventory.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional from fastapi import APIRouter, HTTPException, Depends, Request from modules.auth import limiter, getCurrentUser, getRequestContext, RequestContext from modules.datamodels.datamodelUam import User -from modules.shared.i18nRegistry import apiRouteContext +from modules.shared.i18nRegistry import apiRouteContext, resolveJobMessage routeApiMsg = apiRouteContext("routeRagInventory") logger = logging.getLogger(__name__) @@ -24,6 +24,53 @@ router = APIRouter( ) +_SUB_RESULT_KEYS = ("sharepoint", "outlook", "drive", "gmail", "clickup", "kdrive") + + +def _flattenJobResult(result: Dict[str, Any]) -> Dict[str, Any]: + """Bootstrap handlers nest per-service results (e.g. msft returns + `{"sharepoint": {...}, "outlook": {...}}`). The UI needs per-connection + aggregates AND the first hit limit, so we sum the counters and pick the + most informative `stoppedAtLimit` across sub-services. + + Returns a flat dict with the same keys the UI expects on `lastSuccess`. + """ + subResults = [result[k] for k in _SUB_RESULT_KEYS if isinstance(result.get(k), dict)] + if not subResults: + # Single-service handler that returns flat dict directly (legacy path). + return result + + indexed = sum(int(r.get("indexed") or 0) for r in subResults) + skippedDup = sum(int(r.get("skippedDuplicate") or 0) for r in subResults) + skippedPol = sum(int(r.get("skippedPolicy") or 0) for r in subResults) + failed = sum(int(r.get("failed") or 0) for r in subResults) + bytes_ = sum(int(r.get("bytesProcessed") or 0) for r in subResults) + # Parallel sub-services: wall-clock ≈ slowest one. + durationMs = max((int(r.get("durationMs") or 0) for r in subResults), default=0) + + # First sub-service that hit a limit wins — UI shows one banner per + # connection; if multiple stopped, the first one is informative enough + # and the user re-runs after raising that budget. + stoppedAtLimit: Optional[str] = None + limits: Dict[str, Any] = {} + for r in subResults: + if r.get("stoppedAtLimit"): + stoppedAtLimit = r["stoppedAtLimit"] + limits = r.get("limits") or {} + break + + return { + "indexed": indexed, + "skippedDuplicate": skippedDup, + "skippedPolicy": skippedPol, + "failed": failed, + "bytesProcessed": bytes_, + "durationMs": durationMs, + "stoppedAtLimit": stoppedAtLimit, + "limits": limits, + } + + def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]: """Build per-connection RAG inventory rows. @@ -111,7 +158,17 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L jobs = jobService.listJobs(jobType="connection.bootstrap", limit=50) connJobs = [j for j in jobs if (j.get("payload") or {}).get("connectionId") == connectionId] runningJobs = [ - {"jobId": j["id"], "progress": j.get("progress", 0), "progressMessage": j.get("progressMessage", "")} + { + "jobId": j["id"], + "progress": j.get("progress", 0), + # Server-side translate the structured walker payload into + # the request-context language; frontend renders 1:1 (no + # `t()` on backend-supplied keys). + "progressMessage": ( + resolveJobMessage(j.get("progressMessageData")) + or j.get("progressMessage", "") + ), + } for j in connJobs if j.get("status") in ("PENDING", "RUNNING") ] @@ -126,7 +183,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "finishedAt": j.get("finishedAt"), } elif status == "SUCCESS" and lastSuccess is None: - result = j.get("result") or {} + # Bootstrap handlers may return either a flat dict (single + # service) or a nested dict keyed by sub-service (e.g. msft + # returns {"sharepoint": {...}, "outlook": {...}}). Flatten + # so the UI always sees aggregated counters and the first + # sub-service that hit a limit. + result = _flattenJobResult(j.get("result") or {}) lastSuccess = { "jobId": j["id"], "finishedAt": j.get("finishedAt"), @@ -337,7 +399,10 @@ def _getActiveJobs( "connectionLabel": getattr(conn, "displayLabel", None) or getattr(conn, "authority", connId), "jobType": j.get("jobType", "connection.bootstrap"), "progress": j.get("progress", 0), - "progressMessage": j.get("progressMessage", ""), + "progressMessage": ( + resolveJobMessage(j.get("progressMessageData")) + or j.get("progressMessage", "") + ), }) return active except Exception as e: diff --git a/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py b/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py index e27dae58..90b69bce 100644 --- a/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py +++ b/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py @@ -54,19 +54,53 @@ _CANCEL_CHECK_INTERVAL_S = 3.0 class JobProgressCallback: - """Callable progress reporter with cooperative cancel-check for long-running walkers.""" + """Callable progress reporter with cooperative cancel-check for long-running walkers. + + Two ways to set a progress message: + progressCb(50, "145 Dateien verarbeitet") # legacy plaintext (DE) + progressCb(50, messageKey="{n} Dateien verarbeitet", + messageParams={"n": 145}) # i18n-friendly + + When `messageKey` is given the structured payload is written to + `BackgroundJob.progressMessageData` so the frontend can render it via + `t(key, params)` in the user's UI language. A best-effort rendered + fallback is also stored in `progressMessage` for older clients, logs, + and audit trails. + """ def __init__(self, jobId: str): self._jobId = jobId self._cancelledCache: Optional[bool] = None self._lastCheckedAt: float = 0.0 - def __call__(self, progress: int, message: Optional[str] = None) -> None: + def __call__( + self, + progress: int, + message: Optional[str] = None, + *, + messageKey: Optional[str] = None, + messageParams: Optional[Dict[str, Any]] = None, + ) -> None: try: clamped = max(0, min(100, int(progress))) fields: Dict[str, Any] = {"progress": clamped} - if message is not None: + + if messageKey is not None: + params = messageParams or {} + try: + fallback = messageKey.format(**params) + except (KeyError, IndexError, ValueError) as fmtErr: + fallback = message or messageKey + logger.warning( + "progressCb message format failed for job %s key=%r params=%r: %s", + self._jobId, messageKey, params, fmtErr, + ) + fields["progressMessageData"] = {"key": messageKey, "params": params} + fields["progressMessage"] = (message or fallback)[:500] + elif message is not None: fields["progressMessage"] = message[:500] + fields["progressMessageData"] = None + _updateJob(self._jobId, fields) except Exception as ex: logger.warning("Progress update failed for job %s: %s", self._jobId, ex) diff --git a/modules/serviceCenter/services/serviceChat/mainServiceChat.py b/modules/serviceCenter/services/serviceChat/mainServiceChat.py index 2ca61d7e..61026de0 100644 --- a/modules/serviceCenter/services/serviceChat/mainServiceChat.py +++ b/modules/serviceCenter/services/serviceChat/mainServiceChat.py @@ -534,11 +534,17 @@ class ChatService: ) -> Dict[str, Any]: """Create a new external data source reference. - Returns existing record if connectionId + path already exists (upsert semantics). + Upsert key is `(connectionId, sourceType, path)`. The same `path='/'` + can carry multiple DataSources discriminated by sourceType: the + Connection-Root (sourceType=, e.g. 'msft') plus one per + service (sourceType='sharepointFolder', 'outlookFolder', ...). The + sourceType filter MUST be present, otherwise a Service-Root POST + returns the Connection-Root and toggles cascade onto every sibling. """ from modules.datamodels.datamodelDataSource import DataSource existing = self.interfaceDbApp.db.getRecordset( - DataSource, recordFilter={"connectionId": connectionId, "path": path} + DataSource, + recordFilter={"connectionId": connectionId, "sourceType": sourceType, "path": path}, ) if existing: return existing[0] if isinstance(existing[0], dict) else existing[0].model_dump() diff --git a/modules/serviceCenter/services/serviceKnowledge/_costEstimate.py b/modules/serviceCenter/services/serviceKnowledge/_costEstimate.py new file mode 100644 index 00000000..565c219d --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/_costEstimate.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Indicative cost estimation for a RAG bootstrap run. + +This is **not** a billing-grade forecast: it gives the user a back-of-the-envelope +USD figure for the worst-case full sync, so they can sanity-check before raising +`maxBytes`/`maxItems`. The output always carries the underlying assumptions +(`basis`) so the user can judge plausibility. + +Heuristic: + estimatedTokens = ceil(maxBytes / CHARS_PER_TOKEN_BYTES_FACTOR) + estimatedUsd = estimatedTokens / 1_000_000 * EMBEDDING_USD_PER_MTOKEN + +Defaults match OpenAI `text-embedding-3-small` pricing (2026-Q2). +""" + +from __future__ import annotations + +import math +from typing import Any, Dict + + +CHARS_PER_TOKEN = 4 +EMBEDDING_USD_PER_MTOKEN = 0.02 +DEFAULT_TOKENS_PER_ITEM = 1500 +BYTES_PER_TOKEN_TEXT_FACTOR = 4 +EXTRACTABLE_FRACTION = 0.4 + + +def estimateBootstrapCost(limits: Dict[str, int], kind: str = "files") -> Dict[str, Any]: + """Return an indicative cost estimate dict for a DataSource bootstrap. + + Returned shape:: + + { + "estimatedTokens": int, + "estimatedUsd": float, # rounded to 4 decimals + "basis": { + "kind": "files"|"clickup", + "limits": {...}, + "assumptions": { + "embeddingUsdPerMToken": 0.02, + "charsPerToken": 4, + "extractableFraction": 0.4, + "tokensPerItem": 1500 # only for clickup-like item counts + }, + "notes": "non-binding, depends on real file content..." + } + } + """ + assumptions: Dict[str, Any] = { + "embeddingUsdPerMToken": EMBEDDING_USD_PER_MTOKEN, + "charsPerToken": CHARS_PER_TOKEN, + } + + if kind == "files": + maxBytes = int(limits.get("maxBytes") or 0) + extractableBytes = maxBytes * EXTRACTABLE_FRACTION + estimatedTokens = int(math.ceil(extractableBytes / BYTES_PER_TOKEN_TEXT_FACTOR)) + assumptions["extractableFraction"] = EXTRACTABLE_FRACTION + assumptions["formula"] = "ceil(maxBytes * 0.4 / 4)" + elif kind == "clickup": + maxTasks = int(limits.get("maxTasks") or 0) + maxWorkspaces = max(1, int(limits.get("maxWorkspaces") or 1)) + estimatedTokens = maxTasks * maxWorkspaces * DEFAULT_TOKENS_PER_ITEM + assumptions["tokensPerItem"] = DEFAULT_TOKENS_PER_ITEM + assumptions["formula"] = "maxTasks * maxWorkspaces * 1500" + else: + estimatedTokens = 0 + assumptions["formula"] = "unknown kind, returning zero" + + estimatedUsd = round(estimatedTokens / 1_000_000 * EMBEDDING_USD_PER_MTOKEN, 4) + + return { + "estimatedTokens": estimatedTokens, + "estimatedUsd": estimatedUsd, + "basis": { + "kind": kind, + "limits": dict(limits), + "assumptions": assumptions, + "notes": ( + "Indicative only. Actual cost depends on file types, extractable text " + "ratio, dedup hit-rate, retries, and current embedding model pricing." + ), + }, + } diff --git a/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py new file mode 100644 index 00000000..00180c9f --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py @@ -0,0 +1,342 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Cascade-inherit semantics for DataSource flags (neutralize, ragIndexEnabled, scope). + +Three-state flags allow tree elements to either set an explicit value or +inherit the value from their nearest ancestor in the path hierarchy. The +walker (RAG/Neutralize) and routes resolve the *effective* value; the cascade +helper resets explicit descendant values when a parent is toggled. + +Path-traversal rules: +- A DataSource is identified by `(connectionId, sourceType, path)`. +- The root of a service tree is `path == '/'`. +- Sub-elements have paths like `/folder1/sub`. Their parent path is the + longest prefix path that exists as a DataSource record (string-based). +- If no ancestor with an explicit value exists, the default is `False` + (or `'personal'` for scope) — matching the legacy behavior of NULL = inherit. +""" + +import logging +from typing import Any, Dict, Iterable, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +_INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope") + +# Connection-root DataSources carry the authority as their sourceType +# (e.g. 'msft', 'google'). They sit one level above all service DataSources +# of the same connection in the visual tree, so flag inheritance must +# cross sourceType boundaries — but ONLY from these authority roots. +_AUTHORITY_SOURCE_TYPES = frozenset({"local", "google", "msft", "clickup", "infomaniak"}) + + +def _normalisePath(path: Optional[str]) -> str: + """Normalize a DataSource path to '/'-prefixed, no trailing slash (except root).""" + if not path: + return "/" + p = str(path).strip() + if not p.startswith("/"): + p = "/" + p + if len(p) > 1 and p.endswith("/"): + p = p.rstrip("/") + return p + + +def _flagDefault(flag: str) -> Any: + if flag == "scope": + return "personal" + return False + + +def _isExplicit(value: Any) -> bool: + """A flag value is explicit when it is not None. + + Note: legacy rows may carry empty-string scope; treat as inherit too. + """ + if value is None: + return False + if isinstance(value, str) and value == "": + return False + return True + + +def _getRecordValue(rec: Any, key: str) -> Any: + if isinstance(rec, dict): + return rec.get(key) + return getattr(rec, key, None) + + +def _findAncestorChain( + rec: Dict[str, Any], + allDs: Iterable[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return all ancestor DataSources of `rec` in the same connection, + ordered nearest-first. + + Two ancestor relations are merged: + 1) **same-sourceType path-ancestor** — strict path-prefix within the + same service tree (sharepointFolder, gmailFolder, ...). + 2) **connection-root ancestor** — a DS with `path='/'` and + `sourceType` ∈ authority set (msft, google, ...) is the parent of + every other DS in that connection regardless of sourceType, so a + toggle on the connection node propagates to all services beneath. + + The connection-root is always the most distant ancestor and therefore + sorts after any same-sourceType ancestors. + """ + recPath = _normalisePath(_getRecordValue(rec, "path")) + recSourceType = _getRecordValue(rec, "sourceType") + recConnectionId = _getRecordValue(rec, "connectionId") + sameTypeCandidates: List[Tuple[int, Dict[str, Any]]] = [] + connectionRoot: Optional[Dict[str, Any]] = None + recIsConnectionRoot = recSourceType in _AUTHORITY_SOURCE_TYPES and recPath == "/" + for cand in allDs: + if _getRecordValue(cand, "id") == _getRecordValue(rec, "id"): + continue + if _getRecordValue(cand, "connectionId") != recConnectionId: + continue + candSourceType = _getRecordValue(cand, "sourceType") + candPath = _normalisePath(_getRecordValue(cand, "path")) + if candSourceType == recSourceType: + if candPath == recPath or not _isAncestorPath(candPath, recPath): + continue + sameTypeCandidates.append((len(candPath), cand)) + elif ( + not recIsConnectionRoot + and candSourceType in _AUTHORITY_SOURCE_TYPES + and candPath == "/" + ): + connectionRoot = cand + sameTypeCandidates.sort(key=lambda x: x[0], reverse=True) + chain = [c for _, c in sameTypeCandidates] + if connectionRoot is not None: + chain.append(connectionRoot) + return chain + + +def _isAncestorPath(ancestor: str, descendant: str) -> bool: + """True iff `ancestor` is a strict path-prefix of `descendant`. + + '/' is ancestor of every non-root path. For non-root prefixes, the + descendant must continue with '/' so '/foo' isn't treated as ancestor of + '/foobar'. + """ + if ancestor == descendant: + return False + if ancestor == "/": + return descendant != "/" + return descendant.startswith(ancestor + "/") + + +def getEffectiveFlag( + rec: Dict[str, Any], + flag: str, + sameConnectionDs: Iterable[Dict[str, Any]], +) -> Any: + """Resolve the effective value of a flag via path-traversal. + + Order: own value (if explicit) → nearest ancestor with explicit value → + static default (`False` or `'personal'`). + """ + if flag not in _INHERITABLE_FLAGS: + raise ValueError(f"Unknown inheritable flag: {flag}") + own = _getRecordValue(rec, flag) + if _isExplicit(own): + return own + chain = _findAncestorChain(rec, sameConnectionDs) + for ancestor in chain: + ancestorVal = _getRecordValue(ancestor, flag) + if _isExplicit(ancestorVal): + return ancestorVal + return _flagDefault(flag) + + +def cascadeResetDescendants( + rootIf: Any, + parentRec: Dict[str, Any], + flag: str, +) -> int: + """Reset all explicit descendant values of `flag` to NULL (= inherit). + + Descendant relation mirrors `_findAncestorChain`: + - Connection-root (`path='/'` AND `sourceType` ∈ authorities) is parent + of every other DS in that connection (cross-sourceType cascade). + - Otherwise: same-sourceType strict path-descendants only. + + Only the targeted `flag` is reset; other flags on the descendant are + untouched. + + Returns the number of records updated. + """ + if flag not in _INHERITABLE_FLAGS: + raise ValueError(f"Unknown inheritable flag: {flag}") + from modules.datamodels.datamodelDataSource import DataSource + + connectionId = _getRecordValue(parentRec, "connectionId") + parentSourceType = _getRecordValue(parentRec, "sourceType") + parentPath = _normalisePath(_getRecordValue(parentRec, "path")) + parentId = _getRecordValue(parentRec, "id") + if not connectionId or not parentSourceType: + return 0 + + parentIsConnectionRoot = ( + parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/" + ) + + siblings = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + affected = 0 + for sib in siblings: + sibId = _getRecordValue(sib, "id") + if sibId == parentId: + continue + sibSourceType = _getRecordValue(sib, "sourceType") + sibPath = _normalisePath(_getRecordValue(sib, "path")) + if parentIsConnectionRoot: + # Connection-root resets everything else under this connection. + pass + else: + if sibSourceType != parentSourceType: + continue + if not _isAncestorPath(parentPath, sibPath): + continue + sibVal = _getRecordValue(sib, flag) + if not _isExplicit(sibVal): + continue + try: + rootIf.db.recordModify(DataSource, sibId, {flag: None}) + affected += 1 + except Exception as exc: + logger.warning("Cascade-reset failed for DataSource %s flag=%s: %s", sibId, flag, exc) + if affected: + logger.info( + "Cascade-reset %s on %d descendants of DataSource (connectionId=%s, sourceType=%s, path=%s, connectionRoot=%s)", + flag, affected, connectionId, parentSourceType, parentPath, parentIsConnectionRoot, + ) + return affected + + +def _fdsClassify(fds: Dict[str, Any]) -> str: + """Return 'workspace' | 'table' | 'record' based on the FDS identifier shape.""" + tableName = _getRecordValue(fds, "tableName") or "" + recordFilter = _getRecordValue(fds, "recordFilter") + if tableName == "*": + return "workspace" + if not recordFilter: + return "table" + return "record" + + +def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool: + """Return True iff `parent` FDS is a strict ancestor of `child` FDS. + + Hierarchy within one `workspaceInstanceId`: + workspace-wildcard (tableName='*') → table-wildcard (tableName='X', !recordFilter) + → record-fds (tableName='X', recordFilter.id=...) + table-wildcard (tableName='X') → record-fds (tableName='X', recordFilter.id=...) + """ + parentWsId = _getRecordValue(parent, "workspaceInstanceId") + childWsId = _getRecordValue(child, "workspaceInstanceId") + if not parentWsId or parentWsId != childWsId: + return False + if _getRecordValue(parent, "id") == _getRecordValue(child, "id"): + return False + parentKind = _fdsClassify(parent) + childKind = _fdsClassify(child) + if parentKind == "workspace": + return childKind in ("table", "record") + if parentKind == "table": + if childKind != "record": + return False + return _getRecordValue(parent, "tableName") == _getRecordValue(child, "tableName") + return False + + +def getEffectiveFlagFds( + rec: Dict[str, Any], + flag: str, + sameWorkspaceFds: Iterable[Dict[str, Any]], +) -> Any: + """Resolve effective value of a FeatureDataSource flag. + + Order: own (if explicit) → table-wildcard (if explicit) → + workspace-wildcard (if explicit) → static default. + """ + if flag not in ("neutralize", "scope"): + raise ValueError(f"Unknown inheritable FDS flag: {flag}") + own = _getRecordValue(rec, flag) + if _isExplicit(own): + return own + workspaceFds: List[Dict[str, Any]] = list(sameWorkspaceFds) + ancestors = [a for a in workspaceFds if _fdsIsAncestor(a, rec)] + ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1) + for ancestor in ancestors: + val = _getRecordValue(ancestor, flag) + if _isExplicit(val): + return val + return _flagDefault(flag) + + +def cascadeResetDescendantsFds( + rootIf: Any, + parentRec: Dict[str, Any], + flag: str, +) -> int: + """Reset explicit `flag` to NULL on every descendant FDS of `parentRec`. + + Only the targeted flag is reset; other flags on descendants are untouched. + Returns the number of records updated. + """ + if flag not in ("neutralize", "scope"): + raise ValueError(f"Unknown inheritable FDS flag: {flag}") + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + + workspaceInstanceId = _getRecordValue(parentRec, "workspaceInstanceId") + if not workspaceInstanceId: + return 0 + siblings = rootIf.db.getRecordset( + FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId} + ) + affected = 0 + for sib in siblings: + if not _fdsIsAncestor(parentRec, sib): + continue + sibVal = _getRecordValue(sib, flag) + if not _isExplicit(sibVal): + continue + sibId = _getRecordValue(sib, "id") + try: + rootIf.db.recordModify(FeatureDataSource, sibId, {flag: None}) + affected += 1 + except Exception as exc: + logger.warning("FDS cascade-reset failed for %s flag=%s: %s", sibId, flag, exc) + if affected: + logger.info( + "FDS cascade-reset %s on %d descendants of FDS (workspaceInstanceId=%s, kind=%s)", + flag, affected, workspaceInstanceId, _fdsClassify(parentRec), + ) + return affected + + +def buildEffectiveByConnection( + dataSources: Iterable[Dict[str, Any]], + flag: str, +) -> Dict[str, Any]: + """Pre-compute the effective value of `flag` for every DataSource id. + + Useful for batch operations (walker, route DTOs) that touch many records + at once. O(N²) in the worst case but N is bounded per connection. + """ + if flag not in _INHERITABLE_FLAGS: + raise ValueError(f"Unknown inheritable flag: {flag}") + bySourceType: Dict[Tuple[str, str], List[Dict[str, Any]]] = {} + for ds in dataSources: + connId = _getRecordValue(ds, "connectionId") or "" + srcType = _getRecordValue(ds, "sourceType") or "" + bySourceType.setdefault((connId, srcType), []).append(ds) + + out: Dict[str, Any] = {} + for group in bySourceType.values(): + for rec in group: + recId = _getRecordValue(rec, "id") + out[recId] = getEffectiveFlag(rec, flag, group) + return out diff --git a/modules/serviceCenter/services/serviceKnowledge/_progressMessages.py b/modules/serviceCenter/services/serviceKnowledge/_progressMessages.py new file mode 100644 index 00000000..99d91d6b --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/_progressMessages.py @@ -0,0 +1,23 @@ +"""Central i18n registration for BackgroundJob progress messages. + +Walkers and consumers report progress via ``progressCb(..., messageKey="…", +messageParams={...})``. Those keys are not seen by ``t()`` at call time, so +without a stub registration they would never make it into the boot-time +``UiLanguageSet(xx)`` sync. Importing this module is enough to register +every known key — call sites stay clean while translators can still find +the texts in the standard i18n table. + +Keep this list in lockstep with the ``messageKey=`` arguments used in +``subConnectorSync*.py`` and ``subConnectorIngestConsumer.py``. +""" + +from modules.shared.i18nRegistry import t + +# Bootstrap walkers (one per connector family) +t("{n} Dateien verarbeitet, {indexed} indexiert") +t("{n} Tasks verarbeitet, {indexed} indexiert") +t("{n} Mails verarbeitet, {indexed} indexiert") + +# Ingestion consumer hand-offs +t("Verbindung wird aufgebaut ({authority})") +t("Synchronisierung läuft...") diff --git a/modules/serviceCenter/services/serviceKnowledge/_ragLimits.py b/modules/serviceCenter/services/serviceKnowledge/_ragLimits.py new file mode 100644 index 00000000..de0a4886 --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/_ragLimits.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Centralized RAG bootstrap limits + DataSource-scoped resolution. + +The original walkers (SharePoint, kDrive, gDrive, ClickUp) each carried their +own module-level `MAX_*_DEFAULT` constants and silently stopped indexing once +they were exceeded. That made it impossible for a user with a 500 MB folder to +override the 200 MB cap without a code change. + +This module is the single source of truth for two things: + +1. The canonical default budget per source kind (`FILES_LIMITS_DEFAULT`, + `CLICKUP_LIMITS_DEFAULT`). Walkers fall back to these when a DataSource has + no `settings.ragLimits` yet. + +2. The pure read/lazy-fill helpers that walkers and the API use to merge a + DataSource's stored settings with the defaults. No override layers, no + resolver chain: what is in `DataSource.settings.ragLimits` is what the + walker uses. + +Lazy fill: the first time a DataSource is processed, the defaults are written +to its `settings.ragLimits` so the UI shows real values immediately, even if +the user has never opened the settings modal. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional + + +logger = logging.getLogger(__name__) + + +FILES_LIMITS_DEFAULT: Dict[str, int] = { + "maxItems": 500, + "maxBytes": 200 * 1024 * 1024, + "maxFileSize": 25 * 1024 * 1024, + "maxDepth": 4, +} + + +CLICKUP_LIMITS_DEFAULT: Dict[str, int] = { + "maxTasks": 500, + "maxWorkspaces": 3, + "maxListsPerWorkspace": 20, +} + + +_LIMITS_BY_KIND: Dict[str, Dict[str, int]] = { + "files": FILES_LIMITS_DEFAULT, + "clickup": CLICKUP_LIMITS_DEFAULT, +} + + +def getDefaults(kind: str) -> Dict[str, int]: + """Return a fresh copy of the default budget for the given walker kind. + + `kind` is either "files" (Sharepoint, kDrive, gDrive) or "clickup". + Returning a copy lets callers mutate the result safely. + """ + defaults = _LIMITS_BY_KIND.get(kind) + if defaults is None: + raise ValueError(f"Unknown RAG limit kind: {kind!r}") + return dict(defaults) + + +def getStoredOverrides(dataSource: Optional[Dict[str, Any]], kind: str) -> Dict[str, int]: + """Return ONLY the limits explicitly set on `dataSource.settings.ragLimits`. + + Missing keys are NOT filled with defaults — that is the caller's job (so + a programmatically supplied `limits=` from a Caller still wins when the + DataSource has no override). Pure read, no DB writes. + """ + if not isinstance(dataSource, dict): + return {} + settings = dataSource.get("settings") or {} + if not isinstance(settings, dict): + return {} + stored = settings.get("ragLimits") + if not isinstance(stored, dict): + return {} + allowed = set(_LIMITS_BY_KIND.get(kind, {}).keys()) + out: Dict[str, int] = {} + for key, raw in stored.items(): + if key not in allowed or raw is None: + continue + try: + out[key] = int(raw) + except (TypeError, ValueError): + logger.warning( + "Ignoring non-int ragLimits[%s]=%r on DataSource %s", + key, raw, dataSource.get("id"), + ) + return out + + +def getRagLimits(dataSource: Optional[Dict[str, Any]], kind: str) -> Dict[str, int]: + """Effective RAG limits for the API/cost-estimate use-case. + + Stored overrides win over `getDefaults(kind)`. Walkers should NOT use this + function — they should pass their own caller-limits as the fallback so that + a runtime-supplied `limits=` parameter is honoured (see `getStoredOverrides`). + """ + base = getDefaults(kind) + base.update(getStoredOverrides(dataSource, kind)) + return base diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py index c86aed86..618a9965 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py @@ -141,18 +141,39 @@ _SOURCE_TYPE_MAP = { def _loadRagEnabledDataSources(connectionId: str, dataSourceIds: Optional[list] = None): - """Load DataSource rows with ragIndexEnabled=true for a connection. + """Load DataSource rows whose *effective* ragIndexEnabled is True. - If dataSourceIds is provided (mini-bootstrap), filter to only those IDs. + Cascade-inherit semantics: a DataSource with `ragIndexEnabled=None` + follows its nearest ancestor's value (path-traversal). Walker iterates + over all DataSources whose effective value resolves to True, including + inherited ones. + + Returned dicts carry **resolved** flags (`neutralize`, `scope`) so the + downstream walkers can keep reading `ds.get("neutralize")` directly + without having to know about the inheritance chain. + + If `dataSourceIds` is provided (mini-bootstrap), the explicit set is + intersected with the effective-true set. """ from modules.interfaces.interfaceDbApp import getRootInterface from modules.datamodels.datamodelDataSource import DataSource + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag rootIf = getRootInterface() allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + resolved = [] + for ds in allDs: + effRagIndex = getEffectiveFlag(ds, "ragIndexEnabled", allDs) + if effRagIndex is not True: + continue + dsCopy = dict(ds) if isinstance(ds, dict) else {**ds.__dict__} + dsCopy["neutralize"] = getEffectiveFlag(ds, "neutralize", allDs) + dsCopy["scope"] = getEffectiveFlag(ds, "scope", allDs) + dsCopy["ragIndexEnabled"] = True + resolved.append(dsCopy) if dataSourceIds: - return [ds for ds in allDs if ds.get("id") in dataSourceIds and ds.get("ragIndexEnabled")] - return [ds for ds in allDs if ds.get("ragIndexEnabled")] + resolved = [ds for ds in resolved if ds.get("id") in dataSourceIds] + return resolved async def _bootstrapJobHandler( @@ -167,7 +188,11 @@ async def _bootstrapJobHandler( if not connectionId: raise ValueError("connection.bootstrap requires payload.connectionId") - progressCb(5, f"resolving {authority} connection") + progressCb( + 5, + messageKey="Verbindung wird aufgebaut ({authority})", + messageParams={"authority": authority}, + ) # Defensive consent check try: @@ -225,7 +250,7 @@ async def _bootstrapJobHandler( bootstrapOutlook, ) - progressCb(0, "Synchronisierung läuft...") + progressCb(0, messageKey="Synchronisierung läuft...") spDs = _filterDs("sharepoint") olDs = _filterDs("outlook") async def _noopResult(): @@ -251,7 +276,7 @@ async def _bootstrapJobHandler( bootstrapGmail, ) - progressCb(0, "Synchronisierung läuft...") + progressCb(0, messageKey="Synchronisierung läuft...") gdDs = _filterDs("drive") gmDs = _filterDs("gmail") async def _noopResult(): @@ -274,7 +299,7 @@ async def _bootstrapJobHandler( bootstrapClickup, ) - progressCb(0, "Synchronisierung läuft...") + progressCb(0, messageKey="Synchronisierung läuft...") cuDs = _filterDs("clickup") cuResult = await bootstrapClickup(connectionId=connectionId, progressCb=progressCb, dataSources=cuDs) if cuDs else {"skipped": True, "reason": "no_datasources"} return { @@ -288,7 +313,7 @@ async def _bootstrapJobHandler( bootstrapKdrive, ) - progressCb(0, "Synchronisierung läuft...") + progressCb(0, messageKey="Synchronisierung läuft...") kdDs = _filterDs("kdrive") kdResult = await bootstrapKdrive(connectionId=connectionId, progressCb=progressCb, dataSources=kdDs) if kdDs else {"skipped": True, "reason": "no_datasources"} return { diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py index 959e42c9..28c24275 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py @@ -33,13 +33,21 @@ from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( logger = logging.getLogger(__name__) -MAX_TASKS_DEFAULT = 500 -MAX_WORKSPACES_DEFAULT = 3 -MAX_LISTS_PER_WORKSPACE_DEFAULT = 20 +from modules.serviceCenter.services.serviceKnowledge import _ragLimits as _ragLimitsHelper + +_CLICKUP_DEFAULTS = _ragLimitsHelper.CLICKUP_LIMITS_DEFAULT +MAX_TASKS_DEFAULT = _CLICKUP_DEFAULTS["maxTasks"] +MAX_WORKSPACES_DEFAULT = _CLICKUP_DEFAULTS["maxWorkspaces"] +MAX_LISTS_PER_WORKSPACE_DEFAULT = _CLICKUP_DEFAULTS["maxListsPerWorkspace"] MAX_DESCRIPTION_CHARS_DEFAULT = 8000 MAX_AGE_DAYS_DEFAULT = 180 +def _resolveDataSourceLimits(dsId: str, ds: Dict[str, Any]) -> Dict[str, int]: + """Return explicit RAG-limit overrides stored on the DataSource (or {}).""" + return _ragLimitsHelper.getStoredOverrides(ds, "clickup") + + @dataclass class ClickupBootstrapLimits: maxTasks: int = MAX_TASKS_DEFAULT @@ -236,10 +244,11 @@ async def bootstrapClickup( dsId = ds.get("id", "") dsNeutralize = ds.get("neutralize", False) + eff = _resolveDataSourceLimits(dsId, ds) dsLimits = ClickupBootstrapLimits( - maxTasks=limits.maxTasks, - maxWorkspaces=limits.maxWorkspaces, - maxListsPerWorkspace=limits.maxListsPerWorkspace, + maxTasks=eff.get("maxTasks", limits.maxTasks), + maxWorkspaces=eff.get("maxWorkspaces", limits.maxWorkspaces), + maxListsPerWorkspace=eff.get("maxListsPerWorkspace", limits.maxListsPerWorkspace), maxDescriptionChars=limits.maxDescriptionChars, maxAgeDays=limits.maxAgeDays, includeClosed=limits.includeClosed, @@ -520,7 +529,11 @@ async def _ingestTask( if hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): return try: - progressCb(0, f"{processed} Tasks verarbeitet, {result.indexed} indexiert") + progressCb( + 0, + messageKey="{n} Tasks verarbeitet, {indexed} indexiert", + messageParams={"n": processed, "indexed": result.indexed}, + ) except Exception: pass if processed % 50 == 0: diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py index e27abacb..7600cce0 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py @@ -31,13 +31,21 @@ from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( logger = logging.getLogger(__name__) -MAX_ITEMS_DEFAULT = 500 -MAX_BYTES_DEFAULT = 200 * 1024 * 1024 -MAX_FILE_SIZE_DEFAULT = 25 * 1024 * 1024 +from modules.serviceCenter.services.serviceKnowledge import _ragLimits as _ragLimitsHelper + +_FILES_DEFAULTS = _ragLimitsHelper.FILES_LIMITS_DEFAULT +MAX_ITEMS_DEFAULT = _FILES_DEFAULTS["maxItems"] +MAX_BYTES_DEFAULT = _FILES_DEFAULTS["maxBytes"] +MAX_FILE_SIZE_DEFAULT = _FILES_DEFAULTS["maxFileSize"] +MAX_DEPTH_DEFAULT = _FILES_DEFAULTS["maxDepth"] SKIP_MIME_PREFIXES_DEFAULT = ("video/", "audio/") -MAX_DEPTH_DEFAULT = 4 MAX_AGE_DAYS_DEFAULT = 365 + +def _resolveDataSourceLimits(dsId: str, ds: Dict[str, Any]) -> Dict[str, int]: + """Return explicit RAG-limit overrides stored on the DataSource (or {}).""" + return _ragLimitsHelper.getStoredOverrides(ds, "files") + FOLDER_MIME = "application/vnd.google-apps.folder" @@ -175,12 +183,13 @@ async def bootstrapGdrive( dsId = ds.get("id", "") dsNeutralize = ds.get("neutralize", False) dsMaxAgeDays = ds.get("maxAgeDays", limits.maxAgeDays) + eff = _resolveDataSourceLimits(dsId, ds) dsLimits = GdriveBootstrapLimits( - maxItems=limits.maxItems, - maxBytes=limits.maxBytes, - maxFileSize=limits.maxFileSize, + maxItems=eff.get("maxItems", limits.maxItems), + maxBytes=eff.get("maxBytes", limits.maxBytes), + maxFileSize=eff.get("maxFileSize", limits.maxFileSize), skipMimePrefixes=limits.skipMimePrefixes, - maxDepth=limits.maxDepth, + maxDepth=eff.get("maxDepth", limits.maxDepth), maxAgeDays=dsMaxAgeDays, neutralize=dsNeutralize, ) @@ -459,7 +468,11 @@ async def _ingestOne( processed = result.indexed + result.skippedDuplicate if progressCb is not None and processed % 5 == 0: try: - progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert") + progressCb( + 0, + messageKey="{n} Dateien verarbeitet, {indexed} indexiert", + messageParams={"n": processed, "indexed": result.indexed}, + ) except Exception: pass logger.info( diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py index 3130e942..96f9cecf 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py @@ -474,7 +474,11 @@ async def _ingestMessage( processed = result.indexed + result.skippedDuplicate if progressCb is not None and processed % 5 == 0: try: - progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert") + progressCb( + 0, + messageKey="{n} Mails verarbeitet, {indexed} indexiert", + messageParams={"n": processed, "indexed": result.indexed}, + ) except Exception: pass if processed % 50 == 0: diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py index dcf19e39..f95aafd1 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py @@ -27,11 +27,19 @@ from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( logger = logging.getLogger(__name__) -MAX_ITEMS_DEFAULT = 500 -MAX_BYTES_DEFAULT = 200 * 1024 * 1024 -MAX_FILE_SIZE_DEFAULT = 25 * 1024 * 1024 +from modules.serviceCenter.services.serviceKnowledge import _ragLimits as _ragLimitsHelper + +_FILES_DEFAULTS = _ragLimitsHelper.FILES_LIMITS_DEFAULT +MAX_ITEMS_DEFAULT = _FILES_DEFAULTS["maxItems"] +MAX_BYTES_DEFAULT = _FILES_DEFAULTS["maxBytes"] +MAX_FILE_SIZE_DEFAULT = _FILES_DEFAULTS["maxFileSize"] +MAX_DEPTH_DEFAULT = _FILES_DEFAULTS["maxDepth"] SKIP_MIME_PREFIXES_DEFAULT = ("video/", "audio/") -MAX_DEPTH_DEFAULT = 4 + + +def _resolveDataSourceLimits(dsId: str, ds: Dict[str, Any]) -> Dict[str, int]: + """Return explicit RAG-limit overrides stored on the DataSource (or {}).""" + return _ragLimitsHelper.getStoredOverrides(ds, "files") @dataclass @@ -143,12 +151,13 @@ async def bootstrapKdrive( dsPath = ds.get("path", "") dsId = ds.get("id", "") dsNeutralize = ds.get("neutralize", False) + eff = _resolveDataSourceLimits(dsId, ds) dsLimits = KdriveBootstrapLimits( - maxItems=limits.maxItems, - maxBytes=limits.maxBytes, - maxFileSize=limits.maxFileSize, + maxItems=eff.get("maxItems", limits.maxItems), + maxBytes=eff.get("maxBytes", limits.maxBytes), + maxFileSize=eff.get("maxFileSize", limits.maxFileSize), skipMimePrefixes=limits.skipMimePrefixes, - maxDepth=limits.maxDepth, + maxDepth=eff.get("maxDepth", limits.maxDepth), neutralize=dsNeutralize, ) @@ -416,7 +425,11 @@ async def _ingestOne( processed = result.indexed + result.skippedDuplicate if progressCb is not None and processed % 5 == 0: try: - progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert") + progressCb( + 0, + messageKey="{n} Dateien verarbeitet, {indexed} indexiert", + messageParams={"n": processed, "indexed": result.indexed}, + ) except Exception: pass diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py index 17220d97..e676b156 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py @@ -460,7 +460,11 @@ async def _ingestMessage( processed = result.indexed + result.skippedDuplicate if progressCb is not None and processed % 5 == 0: try: - progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert") + progressCb( + 0, + messageKey="{n} Mails verarbeitet, {indexed} indexiert", + messageParams={"n": processed, "indexed": result.indexed}, + ) except Exception: pass if processed % 50 == 0: diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py index e06fd36b..87c4c92a 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py @@ -30,14 +30,27 @@ from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( logger = logging.getLogger(__name__) -MAX_ITEMS_DEFAULT = 500 -MAX_BYTES_DEFAULT = 200 * 1024 * 1024 -MAX_FILE_SIZE_DEFAULT = 25 * 1024 * 1024 +from modules.serviceCenter.services.serviceKnowledge import _ragLimits as _ragLimitsHelper + +_FILES_DEFAULTS = _ragLimitsHelper.FILES_LIMITS_DEFAULT +MAX_ITEMS_DEFAULT = _FILES_DEFAULTS["maxItems"] +MAX_BYTES_DEFAULT = _FILES_DEFAULTS["maxBytes"] +MAX_FILE_SIZE_DEFAULT = _FILES_DEFAULTS["maxFileSize"] +MAX_DEPTH_DEFAULT = _FILES_DEFAULTS["maxDepth"] SKIP_MIME_PREFIXES_DEFAULT = ("video/", "audio/") -MAX_DEPTH_DEFAULT = 4 MAX_SITES_DEFAULT = 3 +def _resolveDataSourceLimits(dsId: str, ds: Dict[str, Any]) -> Dict[str, int]: + """Return explicit RAG-limit overrides stored on the DataSource. + + Empty dict means "use caller-supplied limits" — never overrides them with + defaults. Used to merge per-DataSource user settings on top of the + walker's runtime limits. + """ + return _ragLimitsHelper.getStoredOverrides(ds, "files") + + @dataclass class SharepointBootstrapLimits: maxItems: int = MAX_ITEMS_DEFAULT @@ -165,12 +178,13 @@ async def bootstrapSharepoint( dsPath = ds.get("path", "") dsId = ds.get("id", "") dsNeutralize = ds.get("neutralize", False) + eff = _resolveDataSourceLimits(dsId, ds) dsLimits = SharepointBootstrapLimits( - maxItems=limits.maxItems, - maxBytes=limits.maxBytes, - maxFileSize=limits.maxFileSize, + maxItems=eff.get("maxItems", limits.maxItems), + maxBytes=eff.get("maxBytes", limits.maxBytes), + maxFileSize=eff.get("maxFileSize", limits.maxFileSize), skipMimePrefixes=limits.skipMimePrefixes, - maxDepth=limits.maxDepth, + maxDepth=eff.get("maxDepth", limits.maxDepth), maxSites=limits.maxSites, neutralize=dsNeutralize, ) @@ -441,7 +455,11 @@ async def _ingestOne( processed = result.indexed + result.skippedDuplicate if progressCb is not None and processed % 5 == 0: try: - progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert") + progressCb( + 0, + messageKey="{n} Dateien verarbeitet, {indexed} indexiert", + messageParams={"n": processed, "indexed": result.indexed}, + ) except Exception: pass if processed % 50 == 0: diff --git a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py b/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py index 10be150d..0deae777 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py +++ b/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py @@ -1,78 +1,32 @@ # Copyright (c) 2025 Patrick Motsch # All rights reserved. -"""Resolve effective policies (neutralize, ragIndexEnabled) for DataSource tree hierarchies. +"""DEPRECATED: Use `_inheritFlags.getEffectiveFlag()` directly. -Tree-inheritance rule: nearest ancestor DataSource with an explicit value wins. -If no ancestor has a value, the default (False) is used. +Thin shim to the new cascade-inherit helper. Kept so external callers don't +break on import — internal walkers consume pre-resolved dicts via +`_loadRagEnabledDataSources`. """ from __future__ import annotations -import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List -logger = logging.getLogger(__name__) +from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag def resolveEffectiveNeutralize( ds: Dict[str, Any], allDataSources: List[Dict[str, Any]], ) -> bool: - """Compute effective neutralize by walking up the path tree. - - A DataSource at /sites/HR/Documents inherits from /sites/HR if - that ancestor has neutralize=True and the child has no explicit override. - """ - ownValue = ds.get("neutralize") - if ownValue is not None and ownValue is not False: - return True - if ownValue is False: - return False - return _findAncestorPolicy(ds, allDataSources, "neutralize") + """DEPRECATED: use `getEffectiveFlag(ds, 'neutralize', allDataSources)`.""" + value = getEffectiveFlag(ds, "neutralize", allDataSources) + return bool(value) def resolveEffectiveRagIndexEnabled( ds: Dict[str, Any], allDataSources: List[Dict[str, Any]], ) -> bool: - """Compute effective ragIndexEnabled by walking up the path tree.""" - ownValue = ds.get("ragIndexEnabled") - if ownValue is True: - return True - if ownValue is False: - return False - return _findAncestorPolicy(ds, allDataSources, "ragIndexEnabled") - - -def _findAncestorPolicy( - ds: Dict[str, Any], - allDataSources: List[Dict[str, Any]], - field: str, -) -> bool: - """Walk ancestors (longest-prefix match) to find an inherited policy value.""" - dsPath = ds.get("path", "") - connectionId = ds.get("connectionId", "") - if not dsPath: - return False - - ancestors = [] - for candidate in allDataSources: - if candidate.get("id") == ds.get("id"): - continue - if candidate.get("connectionId") != connectionId: - continue - candidatePath = candidate.get("path", "") - if not candidatePath: - continue - if dsPath.startswith(candidatePath) and len(candidatePath) < len(dsPath): - ancestors.append(candidate) - - ancestors.sort(key=lambda a: len(a.get("path", "")), reverse=True) - - for ancestor in ancestors: - val = ancestor.get(field) - if val is True: - return True - if val is False: - return False - return False + """DEPRECATED: use `getEffectiveFlag(ds, 'ragIndexEnabled', allDataSources)`.""" + value = getEffectiveFlag(ds, "ragIndexEnabled", allDataSources) + return bool(value) diff --git a/modules/shared/i18nRegistry.py b/modules/shared/i18nRegistry.py index 7e620f8d..06ccb20e 100644 --- a/modules/shared/i18nRegistry.py +++ b/modules/shared/i18nRegistry.py @@ -124,6 +124,48 @@ def t(key: str, context: str = "api", value: str = "") -> str: return _CACHE.get(lang, {}).get(key, f"[{key}]") +def resolveJobMessage(messageData: Optional[Dict[str, Any]], lang: Optional[str] = None) -> Optional[str]: + """Translate a structured BackgroundJob progress payload. + + ``messageData`` shape (written by ``JobProgressCallback`` when callers + pass ``messageKey`` / ``messageParams``):: + + {"key": "{n} Dateien verarbeitet, {indexed} indexiert", + "params": {"n": 145, "indexed": 106}} + + The walker call sites use a string-literal ``messageKey=``; the matching + ``t("…")`` literal lives in the feature's progress-key registration + module (e.g. ``serviceKnowledge/_progressMessages.py``, + ``features/trustee/mainTrustee.py``) so the boot sync picks it up. + + This helper is the **server-side** translation hop so route handlers can + deliver a fully rendered ``progressMessage`` string to the frontend -- + the frontend never calls ``t()`` on backend-supplied keys. + """ + if not messageData or not isinstance(messageData, dict): + return None + key = messageData.get("key") + if not isinstance(key, str) or not key: + return None + params = messageData.get("params") or {} + + if lang is not None: + token = _CURRENT_LANGUAGE.set(lang) + try: + template = t(key) + finally: + _CURRENT_LANGUAGE.reset(token) + else: + template = t(key) + + if isinstance(params, dict) and params: + try: + return template.format(**params) + except (KeyError, IndexError, ValueError): + return template + return template + + def resolveText(value: Any, lang: Optional[str] = None) -> str: """Resolve any value to a translated string for the current request language. diff --git a/scripts/debug_rag_job_result.py b/scripts/debug_rag_job_result.py new file mode 100644 index 00000000..c107f21e --- /dev/null +++ b/scripts/debug_rag_job_result.py @@ -0,0 +1,70 @@ +"""Diagnose: read a connection.bootstrap job result and print its keys. + +Usage (from repo root): + python gateway\scripts\debug_rag_job_result.py + +Prints the most recent SUCCESS connection.bootstrap job per UserConnection so +we can see whether the `stoppedAtLimit` key actually landed in the JSONB +`result` column. If it is missing here, the bug is in the writer (handler or +_markSuccess); if it is present here but absent in the HTTP response, the bug +is in routeRagInventory. +""" +from __future__ import annotations + +import os +import sys +import json +from pathlib import Path + +_HERE = Path(__file__).resolve() +sys.path.insert(0, str(_HERE.parent.parent)) # gateway/ +os.chdir(_HERE.parent.parent) + +from modules.shared.configuration import APP_CONFIG # noqa: E402 +from modules.connectors.connectorDbPostgre import getCachedConnector # noqa: E402 +from modules.datamodels.datamodelBackgroundJob import BackgroundJob # noqa: E402 +from modules.routes.routeRagInventory import _flattenJobResult # noqa: E402 + + +def _main() -> None: + db = getCachedConnector( + dbDatabase=APP_CONFIG.get("DB_DATABASE", "poweron_app"), + dbHost=APP_CONFIG.get("DB_HOST", "localhost"), + dbPort=int(APP_CONFIG.get("DB_PORT", "5432")), + dbUser=APP_CONFIG.get("DB_USER"), + dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), + ) + + rows = db.getRecordset(BackgroundJob) + rows = [r for r in rows if r.get("jobType") == "connection.bootstrap"] + rows = [r for r in rows if r.get("status") == "SUCCESS"] + rows.sort(key=lambda r: r.get("createdAt") or 0, reverse=True) + + if not rows: + print("No SUCCESS connection.bootstrap jobs found.") + return + + seenConnections: set[str] = set() + for j in rows: + connId = (j.get("payload") or {}).get("connectionId", "") + if connId in seenConnections: + continue + seenConnections.add(connId) + result = j.get("result") or {} + flat = _flattenJobResult(result) if isinstance(result, dict) else {} + print("=" * 80) + print(f"jobId = {j.get('id')}") + print(f"connectionId = {connId}") + print(f"finishedAt = {j.get('finishedAt')}") + print(f"raw keys = {sorted(result.keys()) if isinstance(result, dict) else 'N/A'}") + print("--- flattened (what the API will return now) ---") + print(f" indexed = {flat.get('indexed')}") + print(f" skippedDuplicate= {flat.get('skippedDuplicate')}") + print(f" skippedPolicy = {flat.get('skippedPolicy')}") + print(f" stoppedAtLimit = {flat.get('stoppedAtLimit')!r} <-- KEY CHECK") + print(f" limits = {flat.get('limits')}") + print(f" bytesProcessed = {flat.get('bytesProcessed')}") + + +if __name__ == "__main__": + _main() diff --git a/scripts/script_db_migrate_backgroundjob_progress_data.py b/scripts/script_db_migrate_backgroundjob_progress_data.py new file mode 100644 index 00000000..bc5fc348 --- /dev/null +++ b/scripts/script_db_migrate_backgroundjob_progress_data.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""Migration: Add `progressMessageData` JSONB column to BackgroundJob. + +Carries the structured i18n payload that lets the frontend translate +walker progress messages (e.g. "{n} Dateien verarbeitet, {indexed} +indexiert") into the user's UI language. `progressMessage` stays around +as the rendered fallback for older clients and audit logs. + +Safe to run multiple times (checks column existence before acting). + +Usage: + python scripts/script_db_migrate_backgroundjob_progress_data.py [--dry-run] +""" + +import os +import sys +import argparse +import logging +from pathlib import Path + +scriptPath = Path(__file__).resolve() +gatewayPath = scriptPath.parent.parent +sys.path.insert(0, str(gatewayPath)) +os.chdir(str(gatewayPath)) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True) +logger = logging.getLogger(__name__) + +import psycopg2 +from modules.shared.configuration import APP_CONFIG + + +def _getConnection(): + return psycopg2.connect( + host=APP_CONFIG.get("DB_HOST", "localhost"), + port=int(APP_CONFIG.get("DB_PORT", "5432")), + database=APP_CONFIG.get("DB_DATABASE", "poweron_app"), + user=APP_CONFIG.get("DB_USER"), + password=APP_CONFIG.get("DB_PASSWORD_SECRET"), + ) + + +def _columnExists(cur, table: str, column: str) -> bool: + cur.execute( + """SELECT 1 FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = %s AND column_name = %s""", + (table, column), + ) + return cur.fetchone() is not None + + +def _tableExists(cur, table: str) -> bool: + cur.execute( + """SELECT 1 FROM information_schema.tables + WHERE table_schema = 'public' AND table_name = %s""", + (table,), + ) + return cur.fetchone() is not None + + +def migrate(dryRun: bool = False): + conn = _getConnection() + conn.autocommit = False + cur = conn.cursor() + + table, column = "BackgroundJob", "progressMessageData" + executed = [] + + if not _tableExists(cur, table): + logger.warning("SKIP: table %s does not exist yet (will be created on next ORM init)", table) + elif _columnExists(cur, table, column): + logger.info("SKIP: %s.%s already exists", table, column) + else: + sql = f'ALTER TABLE public."{table}" ADD COLUMN "{column}" JSONB DEFAULT NULL;' + logger.info("EXEC: %s", sql) + if not dryRun: + cur.execute(sql) + executed.append(sql) + + if not dryRun and executed: + conn.commit() + logger.info("Migration committed (%d statements)", len(executed)) + elif dryRun and executed: + conn.rollback() + logger.info("DRY RUN -- would execute %d statements", len(executed)) + else: + logger.info("Nothing to do -- schema already up to date") + + cur.close() + conn.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dry-run", action="store_true", help="Print SQL without executing") + args = parser.parse_args() + migrate(dryRun=args.dry_run) diff --git a/scripts/script_db_migrate_datasource_inherit.py b/scripts/script_db_migrate_datasource_inherit.py new file mode 100644 index 00000000..3444cbee --- /dev/null +++ b/scripts/script_db_migrate_datasource_inherit.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +"""Migration: Drop NOT NULL on DataSource/FeatureDataSource cascade-inherit flags. + +Switches three-valued semantics (NULL = inherit, True/False = explicit) for: + - DataSource.neutralize, ragIndexEnabled, scope + - FeatureDataSource.neutralize, scope + +Existing rows keep their explicit values; only new records (or explicit reset +via cascade) start with NULL. Migration is non-destructive and idempotent. + +Safe to run multiple times. + +Usage: + python scripts/script_db_migrate_datasource_inherit.py [--dry-run] +""" + +import os +import sys +import argparse +import logging +from pathlib import Path + +scriptPath = Path(__file__).resolve() +gatewayPath = scriptPath.parent.parent +sys.path.insert(0, str(gatewayPath)) +os.chdir(str(gatewayPath)) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True) +logger = logging.getLogger(__name__) + +import psycopg2 +from modules.shared.configuration import APP_CONFIG + + +def _getConnection(): + return psycopg2.connect( + host=APP_CONFIG.get("DB_HOST", "localhost"), + port=int(APP_CONFIG.get("DB_PORT", "5432")), + database=APP_CONFIG.get("DB_DATABASE", "poweron_app"), + user=APP_CONFIG.get("DB_USER"), + password=APP_CONFIG.get("DB_PASSWORD_SECRET"), + ) + + +def _tableExists(cur, table: str) -> bool: + cur.execute( + """SELECT 1 FROM information_schema.tables + WHERE table_schema = 'public' AND table_name = %s""", + (table,), + ) + return cur.fetchone() is not None + + +def _columnIsNullable(cur, table: str, column: str) -> bool: + cur.execute( + """SELECT is_nullable FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = %s AND column_name = %s""", + (table, column), + ) + row = cur.fetchone() + if not row: + return False + return row[0] == "YES" + + +def migrate(dryRun: bool = False): + conn = _getConnection() + conn.autocommit = False + cur = conn.cursor() + + targets = [ + ("DataSource", "neutralize"), + ("DataSource", "ragIndexEnabled"), + ("DataSource", "scope"), + ("FeatureDataSource", "neutralize"), + ("FeatureDataSource", "scope"), + ] + + executed = [] + for table, column in targets: + if not _tableExists(cur, table): + logger.warning("SKIP: table %s does not exist yet", table) + continue + if _columnIsNullable(cur, table, column): + logger.info("SKIP: %s.%s already nullable", table, column) + continue + sql = f'ALTER TABLE public."{table}" ALTER COLUMN "{column}" DROP NOT NULL;' + logger.info("EXEC: %s", sql) + if not dryRun: + cur.execute(sql) + executed.append(sql) + + if not dryRun and executed: + conn.commit() + logger.info("Migration committed (%d statements)", len(executed)) + elif dryRun and executed: + conn.rollback() + logger.info("DRY RUN -- would execute %d statements", len(executed)) + else: + logger.info("Nothing to do -- schema already nullable") + + cur.close() + conn.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dry-run", action="store_true", help="Print SQL without executing") + args = parser.parse_args() + migrate(dryRun=args.dry_run) diff --git a/scripts/script_db_migrate_datasource_settings.py b/scripts/script_db_migrate_datasource_settings.py new file mode 100644 index 00000000..9e821221 --- /dev/null +++ b/scripts/script_db_migrate_datasource_settings.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +"""Migration: Add `settings` JSONB column to DataSource and FeatureDataSource. + +This is a one-off migration for the UDB DataSource Settings (Settings-Icon) +feature: walkers read RAG limits (maxBytes, maxFileSize, maxItems, maxDepth) +from this JSON blob, the UI edits them. Existing rows get NULL until the +next bootstrap lazy-fills sensible defaults from `_ragLimits.RAG_LIMITS_DEFAULT`. + +Safe to run multiple times (checks column existence before acting). + +Usage: + python scripts/script_db_migrate_datasource_settings.py [--dry-run] +""" + +import os +import sys +import argparse +import logging +from pathlib import Path + +scriptPath = Path(__file__).resolve() +gatewayPath = scriptPath.parent.parent +sys.path.insert(0, str(gatewayPath)) +os.chdir(str(gatewayPath)) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True) +logger = logging.getLogger(__name__) + +import psycopg2 +from modules.shared.configuration import APP_CONFIG + + +def _getConnection(): + return psycopg2.connect( + host=APP_CONFIG.get("DB_HOST", "localhost"), + port=int(APP_CONFIG.get("DB_PORT", "5432")), + database=APP_CONFIG.get("DB_DATABASE", "poweron_app"), + user=APP_CONFIG.get("DB_USER"), + password=APP_CONFIG.get("DB_PASSWORD_SECRET"), + ) + + +def _columnExists(cur, table: str, column: str) -> bool: + cur.execute( + """SELECT 1 FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = %s AND column_name = %s""", + (table, column), + ) + return cur.fetchone() is not None + + +def _tableExists(cur, table: str) -> bool: + cur.execute( + """SELECT 1 FROM information_schema.tables + WHERE table_schema = 'public' AND table_name = %s""", + (table,), + ) + return cur.fetchone() is not None + + +def migrate(dryRun: bool = False): + conn = _getConnection() + conn.autocommit = False + cur = conn.cursor() + + targets = [ + ("DataSource", "settings"), + ("FeatureDataSource", "settings"), + ] + + executed = [] + for table, column in targets: + if not _tableExists(cur, table): + logger.warning("SKIP: table %s does not exist yet (will be created on next ORM init)", table) + continue + if _columnExists(cur, table, column): + logger.info("SKIP: %s.%s already exists", table, column) + continue + sql = f'ALTER TABLE public."{table}" ADD COLUMN "{column}" JSONB DEFAULT NULL;' + logger.info("EXEC: %s", sql) + if not dryRun: + cur.execute(sql) + executed.append(sql) + + if not dryRun and executed: + conn.commit() + logger.info("Migration committed (%d statements)", len(executed)) + elif dryRun and executed: + conn.rollback() + logger.info("DRY RUN -- would execute %d statements", len(executed)) + else: + logger.info("Nothing to do -- schema already up to date") + + cur.close() + conn.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dry-run", action="store_true", help="Print SQL without executing") + args = parser.parse_args() + migrate(dryRun=args.dry_run) diff --git a/tests/unit/services/test_costEstimate.py b/tests/unit/services/test_costEstimate.py new file mode 100644 index 00000000..e49aca6a --- /dev/null +++ b/tests/unit/services/test_costEstimate.py @@ -0,0 +1,55 @@ +"""Unit tests for `_costEstimate` heuristic. + +Validates the output shape, basic formulas, and that 'basis' annotations +are always present (the user-facing transparency contract). +""" + +from __future__ import annotations + +import unittest + +from modules.serviceCenter.services.serviceKnowledge import _costEstimate + + +class TestCostEstimate(unittest.TestCase): + def test_files_shape(self): + result = _costEstimate.estimateBootstrapCost( + {"maxBytes": 200 * 1024 * 1024}, kind="files", + ) + self.assertIn("estimatedTokens", result) + self.assertIn("estimatedUsd", result) + self.assertIn("basis", result) + self.assertIn("assumptions", result["basis"]) + self.assertIn("formula", result["basis"]["assumptions"]) + self.assertIn("notes", result["basis"]) + + def test_files_doubling_maxBytes_doubles_tokens(self): + low = _costEstimate.estimateBootstrapCost({"maxBytes": 100 * 1024 * 1024}, kind="files") + high = _costEstimate.estimateBootstrapCost({"maxBytes": 200 * 1024 * 1024}, kind="files") + self.assertEqual(high["estimatedTokens"], low["estimatedTokens"] * 2) + + def test_clickup_uses_tasks_and_workspaces(self): + result = _costEstimate.estimateBootstrapCost( + {"maxTasks": 100, "maxWorkspaces": 2, "maxListsPerWorkspace": 10}, + kind="clickup", + ) + expectedTokens = 100 * 2 * _costEstimate.DEFAULT_TOKENS_PER_ITEM + self.assertEqual(result["estimatedTokens"], expectedTokens) + + def test_unknown_kind_returns_zero(self): + result = _costEstimate.estimateBootstrapCost({}, kind="totally-unknown") + self.assertEqual(result["estimatedTokens"], 0) + self.assertEqual(result["estimatedUsd"], 0.0) + + def test_usd_is_rounded_4_decimals(self): + result = _costEstimate.estimateBootstrapCost({"maxBytes": 1024 * 1024}, kind="files") + rounded = round(result["estimatedUsd"], 4) + self.assertEqual(result["estimatedUsd"], rounded) + + def test_basis_includes_input_limits(self): + result = _costEstimate.estimateBootstrapCost({"maxBytes": 42}, kind="files") + self.assertEqual(result["basis"]["limits"]["maxBytes"], 42) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/services/test_inheritFlags.py b/tests/unit/services/test_inheritFlags.py new file mode 100644 index 00000000..b177e767 --- /dev/null +++ b/tests/unit/services/test_inheritFlags.py @@ -0,0 +1,330 @@ +"""Unit tests for `_inheritFlags` cascade-inherit helpers. + +Verifies: +- getEffectiveFlag walks ancestors via path-prefix matching +- root default is False (or 'personal' for scope) when nothing explicit in chain +- only same-connectionId AND same-sourceType ancestors are considered +- cascadeResetDescendants only touches descendants with explicit values for THAT flag +- '/' is treated as ancestor of every non-root path +- '/foo' is NOT ancestor of '/foobar' (must require '/' separator) +""" + +from __future__ import annotations + +import unittest +from typing import List +from unittest.mock import MagicMock + +from modules.serviceCenter.services.serviceKnowledge import _inheritFlags + + +def _ds(idVal: str, path: str, **flags) -> dict: + """Build a DataSource dict with sensible defaults for a fixture.""" + base = { + "id": idVal, + "connectionId": "conn-1", + "sourceType": "sharepointFolder", + "path": path, + "neutralize": None, + "ragIndexEnabled": None, + "scope": None, + } + base.update(flags) + return base + + +class TestEffectiveFlag(unittest.TestCase): + def test_explicit_own_value_wins(self): + root = _ds("r", "/", neutralize=False) + leaf = _ds("l", "/folder/sub", neutralize=True) + self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf])) + + def test_inherits_from_root_when_own_is_none(self): + root = _ds("r", "/", neutralize=True) + leaf = _ds("l", "/folder/sub") + self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf])) + + def test_default_false_when_chain_empty(self): + leaf = _ds("l", "/folder/sub") + self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [leaf])) + + def test_nearest_ancestor_wins_over_distant(self): + root = _ds("r", "/", neutralize=False) + mid = _ds("m", "/folder", neutralize=True) + leaf = _ds("l", "/folder/sub") + self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, mid, leaf])) + + def test_different_connection_ignored(self): + otherConn = _ds("o", "/", connectionId="conn-2", neutralize=True) + leaf = _ds("l", "/folder") + self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [otherConn, leaf])) + + def test_different_sourcetype_ignored(self): + otherType = _ds("o", "/", sourceType="outlookFolder", neutralize=True) + leaf = _ds("l", "/folder") + self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [otherType, leaf])) + + def test_path_separator_required(self): + """`/foo` must NOT be ancestor of `/foobar` (no shared `/` boundary).""" + notAncestor = _ds("a", "/foo", neutralize=True) + leaf = _ds("l", "/foobar") + self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [notAncestor, leaf])) + + def test_root_is_ancestor_of_everything(self): + root = _ds("r", "/", neutralize=True) + leaf = _ds("l", "/anything/anywhere") + self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf])) + + def test_scope_inheritance_with_string_default(self): + root = _ds("r", "/", scope="mandate") + leaf = _ds("l", "/folder") + self.assertEqual(_inheritFlags.getEffectiveFlag(leaf, "scope", [root, leaf]), "mandate") + + def test_scope_default_personal_when_empty(self): + leaf = _ds("l", "/folder") + self.assertEqual(_inheritFlags.getEffectiveFlag(leaf, "scope", [leaf]), "personal") + + def test_unknown_flag_raises(self): + leaf = _ds("l", "/") + with self.assertRaises(ValueError): + _inheritFlags.getEffectiveFlag(leaf, "unknownFlag", [leaf]) + + def test_explicit_false_overrides_inherited_true(self): + """Explicit False on a child must NOT cascade up to True from an ancestor.""" + root = _ds("r", "/", neutralize=True) + leaf = _ds("l", "/folder", neutralize=False) + self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf])) + + def test_connection_root_inherits_cross_sourcetype(self): + """Connection-root (sourceType=authority, path='/') is ancestor of all DS in that connection.""" + connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) + spService = _ds("sp", "/", sourceType="sharepointFolder") + olService = _ds("ol", "/", sourceType="outlookFolder") + self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", [connRoot, spService, olService])) + self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", [connRoot, spService, olService])) + + def test_same_sourcetype_ancestor_wins_over_connection_root(self): + """A same-sourceType service-root ancestor beats the connection-root.""" + connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) + spRoot = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) + spLeaf = _ds("spl", "/sites/x", sourceType="sharepointFolder") + self.assertFalse(_inheritFlags.getEffectiveFlag(spLeaf, "neutralize", [connRoot, spRoot, spLeaf])) + + def test_connection_root_does_not_self_inherit(self): + """Connection-root has no ancestor — does not infinite-loop on itself.""" + connRoot = _ds("conn", "/", sourceType="msft") + self.assertFalse(_inheritFlags.getEffectiveFlag(connRoot, "neutralize", [connRoot])) + + +class TestCascadeReset(unittest.TestCase): + def _makeRootIf(self, dataSources: List[dict]): + rootIf = MagicMock() + rootIf.db.getRecordset = MagicMock(return_value=dataSources) + modified = [] + + def _modify(model, recordId, fields): + modified.append((recordId, fields)) + rootIf.db.recordModify = MagicMock(side_effect=_modify) + return rootIf, modified + + def test_resets_only_explicit_descendants(self): + parent = _ds("p", "/sites", neutralize=True) + explicitChild = _ds("c1", "/sites/folder1", neutralize=False) + inheritChild = _ds("c2", "/sites/folder2") # inherit -> not touched + sibling = _ds("s", "/other", neutralize=True) # NOT a descendant + rootIf, modified = self._makeRootIf([parent, explicitChild, inheritChild, sibling]) + + affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + + self.assertEqual(affected, 1) + self.assertEqual(modified, [("c1", {"neutralize": None})]) + + def test_does_not_touch_other_flags(self): + parent = _ds("p", "/sites", neutralize=True) + child = _ds("c", "/sites/sub", neutralize=False, ragIndexEnabled=True) + rootIf, modified = self._makeRootIf([parent, child]) + + _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + + self.assertEqual(modified, [("c", {"neutralize": None})]) + # ragIndexEnabled and scope on the child must remain untouched. + + def test_does_not_cross_sourcetype(self): + """Non-connection-root parents stay within their sourceType for cascade.""" + parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder") + otherTypeDescendant = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder") + rootIf, modified = self._makeRootIf([parent, otherTypeDescendant]) + + affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + + self.assertEqual(affected, 0) + self.assertEqual(modified, []) + + def test_connection_root_cascades_cross_sourcetype(self): + """Toggle on connection-root cascades into every explicit DS of that connection.""" + connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) + spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) + olInherit = _ds("ol", "/", sourceType="outlookFolder") + spLeafExplicit = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True) + rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeafExplicit]) + + affected = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize") + + # spExplicit and spLeafExplicit had explicit values → reset. olInherit untouched. + self.assertEqual(affected, 2) + self.assertEqual({m[0] for m in modified}, {"sp", "sp-leaf"}) + for _, fields in modified: + self.assertEqual(fields, {"neutralize": None}) + + def test_unknown_flag_raises(self): + parent = _ds("p", "/", neutralize=True) + rootIf, _ = self._makeRootIf([parent]) + with self.assertRaises(ValueError): + _inheritFlags.cascadeResetDescendants(rootIf, parent, "unknownFlag") + + +def _fds(idVal: str, *, tableName: str, recordFilter=None, **flags) -> dict: + """Build a FeatureDataSource dict fixture.""" + base = { + "id": idVal, + "workspaceInstanceId": "ws-1", + "tableName": tableName, + "recordFilter": recordFilter, + "neutralize": None, + "scope": None, + } + base.update(flags) + return base + + +class TestFdsClassifyAndAncestry(unittest.TestCase): + def test_classify_workspace_wildcard(self): + self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="*")), "workspace") + + def test_classify_table_wildcard(self): + self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="Pos")), "table") + + def test_classify_record_specific(self): + rec = _fds("a", tableName="Pos", recordFilter={"id": "r-1"}) + self.assertEqual(_inheritFlags._fdsClassify(rec), "record") + + def test_workspace_is_ancestor_of_table_and_record(self): + ws = _fds("ws", tableName="*") + tbl = _fds("t", tableName="Pos") + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + self.assertTrue(_inheritFlags._fdsIsAncestor(ws, tbl)) + self.assertTrue(_inheritFlags._fdsIsAncestor(ws, rec)) + + def test_table_is_ancestor_of_record_same_table_only(self): + tbl = _fds("t", tableName="Pos") + recSame = _fds("r1", tableName="Pos", recordFilter={"id": "1"}) + recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}) + self.assertTrue(_inheritFlags._fdsIsAncestor(tbl, recSame)) + self.assertFalse(_inheritFlags._fdsIsAncestor(tbl, recOther)) + + def test_record_has_no_descendants(self): + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + tbl = _fds("t", tableName="Pos") + self.assertFalse(_inheritFlags._fdsIsAncestor(rec, tbl)) + + def test_no_cross_workspace_ancestry(self): + ws = _fds("ws", tableName="*", workspaceInstanceId="ws-A") + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, workspaceInstanceId="ws-B") + self.assertFalse(_inheritFlags._fdsIsAncestor(ws, rec)) + + +class TestFdsEffectiveFlag(unittest.TestCase): + def test_own_explicit_wins(self): + ws = _fds("ws", tableName="*", neutralize=False) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) + self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [ws, rec])) + + def test_inherits_from_table_wildcard(self): + tbl = _fds("t", tableName="Pos", neutralize=True) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [tbl, rec])) + + def test_table_wildcard_beats_workspace_wildcard(self): + ws = _fds("ws", tableName="*", neutralize=False) + tbl = _fds("t", tableName="Pos", neutralize=True) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [ws, tbl, rec])) + + def test_workspace_wildcard_inherits_when_no_table(self): + ws = _fds("ws", tableName="*", neutralize=True) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [ws, rec])) + + def test_default_false_when_chain_empty(self): + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + self.assertFalse(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [rec])) + + def test_unknown_flag_raises(self): + rec = _fds("r", tableName="*") + with self.assertRaises(ValueError): + _inheritFlags.getEffectiveFlagFds(rec, "ragIndexEnabled", [rec]) + + +class TestFdsCascadeReset(unittest.TestCase): + def _makeRootIf(self, fdses): + rootIf = MagicMock() + rootIf.db.getRecordset = MagicMock(return_value=fdses) + modified = [] + + def _modify(model, recordId, fields): + modified.append((recordId, fields)) + rootIf.db.recordModify = MagicMock(side_effect=_modify) + return rootIf, modified + + def test_workspace_cascades_to_all_explicit_descendants(self): + ws = _fds("ws", tableName="*", neutralize=True) + tblExplicit = _fds("t", tableName="Pos", neutralize=False) + tblInherit = _fds("t2", tableName="Other") + recExplicit = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) + rootIf, modified = self._makeRootIf([ws, tblExplicit, tblInherit, recExplicit]) + + affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize") + + self.assertEqual(affected, 2) + self.assertEqual({m[0] for m in modified}, {"t", "r"}) + + def test_table_cascades_only_to_same_table_records(self): + tbl = _fds("t", tableName="Pos", neutralize=True) + recSame = _fds("r1", tableName="Pos", recordFilter={"id": "1"}, neutralize=False) + recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}, neutralize=False) + rootIf, modified = self._makeRootIf([tbl, recSame, recOther]) + + affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize") + + self.assertEqual(affected, 1) + self.assertEqual(modified, [("r1", {"neutralize": None})]) + + def test_record_has_no_cascade(self): + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) + rootIf, modified = self._makeRootIf([rec]) + affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize") + self.assertEqual(affected, 0) + self.assertEqual(modified, []) + + def test_unknown_flag_raises(self): + ws = _fds("ws", tableName="*", neutralize=True) + rootIf, _ = self._makeRootIf([ws]) + with self.assertRaises(ValueError): + _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled") + + +class TestPathNormalization(unittest.TestCase): + def test_empty_path_normalises_to_root(self): + self.assertEqual(_inheritFlags._normalisePath(""), "/") + self.assertEqual(_inheritFlags._normalisePath(None), "/") + + def test_trailing_slash_stripped(self): + self.assertEqual(_inheritFlags._normalisePath("/foo/"), "/foo") + self.assertEqual(_inheritFlags._normalisePath("/"), "/") + + def test_leading_slash_added(self): + self.assertEqual(_inheritFlags._normalisePath("foo/bar"), "/foo/bar") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/services/test_knowledge_ingest_consumer.py b/tests/unit/services/test_knowledge_ingest_consumer.py index 6b27a6e8..9884079e 100644 --- a/tests/unit/services/test_knowledge_ingest_consumer.py +++ b/tests/unit/services/test_knowledge_ingest_consumer.py @@ -99,11 +99,18 @@ def test_onConnectionRevoked_ignores_missing_id(monkeypatch): assert seen == [] +def _stubRagEnabledDs(monkeypatch, dataSources): + """Stub _loadRagEnabledDataSources so tests don't need a live DB.""" + monkeypatch.setattr(consumer, "_loadRagEnabledDataSources", lambda *_, **__: dataSources) + + def test_bootstrap_job_skips_unsupported_authority(monkeypatch): + _stubRagEnabledDs(monkeypatch, [{"id": "ds1", "sourceType": "unknownType"}]) + async def _run(): result = await consumer._bootstrapJobHandler( {"payload": {"connectionId": "c1", "authority": "slack"}}, - lambda *_: None, + lambda *_, **__: None, ) return result @@ -114,13 +121,18 @@ def test_bootstrap_job_skips_unsupported_authority(monkeypatch): def test_bootstrap_job_dispatches_msft_parts(monkeypatch): + _stubRagEnabledDs(monkeypatch, [ + {"id": "ds1", "sourceType": "sharepointFolder"}, + {"id": "ds2", "sourceType": "outlookFolder"}, + ]) + calls = {"sp": 0, "ol": 0} - async def _fakeSp(connectionId, progressCb=None): + async def _fakeSp(connectionId, progressCb=None, dataSources=None): calls["sp"] += 1 return {"indexed": 1} - async def _fakeOl(connectionId, progressCb=None): + async def _fakeOl(connectionId, progressCb=None, dataSources=None): calls["ol"] += 1 return {"indexed": 2} @@ -142,7 +154,7 @@ def test_bootstrap_job_dispatches_msft_parts(monkeypatch): async def _run(): return await consumer._bootstrapJobHandler( {"payload": {"connectionId": "c1", "authority": "msft"}}, - lambda *_: None, + lambda *_, **__: None, ) result = asyncio.run(_run()) @@ -152,13 +164,18 @@ def test_bootstrap_job_dispatches_msft_parts(monkeypatch): def test_bootstrap_job_dispatches_google_parts(monkeypatch): + _stubRagEnabledDs(monkeypatch, [ + {"id": "ds1", "sourceType": "googleDriveFolder"}, + {"id": "ds2", "sourceType": "gmailFolder"}, + ]) + calls = {"gd": 0, "gm": 0} - async def _fakeGd(connectionId, progressCb=None): + async def _fakeGd(connectionId, progressCb=None, dataSources=None): calls["gd"] += 1 return {"indexed": 7} - async def _fakeGm(connectionId, progressCb=None): + async def _fakeGm(connectionId, progressCb=None, dataSources=None): calls["gm"] += 1 return {"indexed": 11} @@ -180,7 +197,7 @@ def test_bootstrap_job_dispatches_google_parts(monkeypatch): async def _run(): return await consumer._bootstrapJobHandler( {"payload": {"connectionId": "c1", "authority": "google"}}, - lambda *_: None, + lambda *_, **__: None, ) result = asyncio.run(_run()) @@ -190,9 +207,13 @@ def test_bootstrap_job_dispatches_google_parts(monkeypatch): def test_bootstrap_job_dispatches_clickup_part(monkeypatch): + _stubRagEnabledDs(monkeypatch, [ + {"id": "ds1", "sourceType": "clickupList"}, + ]) + calls = {"cu": 0} - async def _fakeCu(connectionId, progressCb=None): + async def _fakeCu(connectionId, progressCb=None, dataSources=None): calls["cu"] += 1 return {"indexed": 4} @@ -207,7 +228,7 @@ def test_bootstrap_job_dispatches_clickup_part(monkeypatch): async def _run(): return await consumer._bootstrapJobHandler( {"payload": {"connectionId": "c1", "authority": "clickup"}}, - lambda *_: None, + lambda *_, **__: None, ) result = asyncio.run(_run()) diff --git a/tests/unit/services/test_ragLimits.py b/tests/unit/services/test_ragLimits.py new file mode 100644 index 00000000..bb336ed3 --- /dev/null +++ b/tests/unit/services/test_ragLimits.py @@ -0,0 +1,79 @@ +"""Unit tests for `_ragLimits` central helpers. + +Verifies: +- defaults are returned as fresh copies (no mutation leakage) +- getStoredOverrides returns ONLY explicit overrides (walker contract) +- getRagLimits merges defaults with overrides (API/cost-estimate contract) +- non-int values in stored settings are dropped, not silently coerced +""" + +from __future__ import annotations + +import unittest + +from modules.serviceCenter.services.serviceKnowledge import _ragLimits + + +class TestGetDefaults(unittest.TestCase): + def test_files_defaults_have_all_keys(self): + d = _ragLimits.getDefaults("files") + self.assertEqual(set(d.keys()), {"maxItems", "maxBytes", "maxFileSize", "maxDepth"}) + self.assertEqual(d["maxBytes"], 200 * 1024 * 1024) + + def test_clickup_defaults(self): + d = _ragLimits.getDefaults("clickup") + self.assertEqual(set(d.keys()), {"maxTasks", "maxWorkspaces", "maxListsPerWorkspace"}) + + def test_defaults_are_a_fresh_copy(self): + d1 = _ragLimits.getDefaults("files") + d1["maxBytes"] = 1 + d2 = _ragLimits.getDefaults("files") + self.assertEqual(d2["maxBytes"], 200 * 1024 * 1024) + + def test_unknown_kind_raises(self): + with self.assertRaises(ValueError): + _ragLimits.getDefaults("unknown") + + +class TestGetStoredOverrides(unittest.TestCase): + def test_no_settings_returns_empty_dict(self): + self.assertEqual(_ragLimits.getStoredOverrides({"id": "x", "settings": None}, "files"), {}) + + def test_only_explicit_overrides_returned(self): + ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": 999}}} + self.assertEqual(_ragLimits.getStoredOverrides(ds, "files"), {"maxBytes": 999}) + + def test_unknown_keys_dropped(self): + ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": 999, "bogus": 1}}} + self.assertEqual(_ragLimits.getStoredOverrides(ds, "files"), {"maxBytes": 999}) + + def test_non_int_dropped(self): + ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": "not-a-number"}}} + self.assertEqual(_ragLimits.getStoredOverrides(ds, "files"), {}) + + def test_none_or_garbage_settings_safe(self): + self.assertEqual(_ragLimits.getStoredOverrides(None, "files"), {}) + self.assertEqual(_ragLimits.getStoredOverrides({"id": "x", "settings": "garbage"}, "files"), {}) + + +class TestGetRagLimits(unittest.TestCase): + def test_no_settings_returns_defaults(self): + result = _ragLimits.getRagLimits({"id": "x", "settings": None}, "files") + self.assertEqual(result, _ragLimits.FILES_LIMITS_DEFAULT) + + def test_partial_override_merges_with_defaults(self): + ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": 999}}} + result = _ragLimits.getRagLimits(ds, "files") + self.assertEqual(result["maxBytes"], 999) + self.assertEqual(result["maxItems"], _ragLimits.FILES_LIMITS_DEFAULT["maxItems"]) + + def test_caller_can_distinguish_unset_from_set(self): + """Walker contract: an unset key MUST NOT appear in `getStoredOverrides`.""" + ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": 999}}} + overrides = _ragLimits.getStoredOverrides(ds, "files") + self.assertIn("maxBytes", overrides) + self.assertNotIn("maxItems", overrides) + + +if __name__ == "__main__": + unittest.main() From 1ed462ad13742c7bf4b52a5abb7a83395a2ddb4c Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Tue, 19 May 2026 16:48:01 +0200 Subject: [PATCH 3/6] fixes rag and workflow --- modules/connectors/connectorDbPostgre.py | 16 +- .../datamodels/datamodelFeatureDataSource.py | 8 + modules/demoConfigs/investorDemo2026.py | 25 +- modules/demoConfigs/pwgDemo2026.py | 23 +- .../workspace/datamodelFeatureWorkspace.py | 9 +- .../workspace/routeFeatureWorkspace.py | 625 +++------- modules/routes/routeAdminDemoConfig.py | 14 +- modules/routes/routeDataConnections.py | 7 +- modules/routes/routeDataFiles.py | 261 ++++- modules/routes/routeDataSources.py | 194 +++- modules/routes/routeRagInventory.py | 271 ++++- modules/security/rbac.py | 16 +- .../coreTools/_dataSourceTools.py | 7 +- .../coreTools/_featureSubAgentTools.py | 6 +- .../serviceAgent/featureDataProvider.py | 42 +- .../services/serviceKnowledge/_buildTree.py | 1020 +++++++++++++++++ .../serviceKnowledge/_inheritFlags.py | 503 ++++++-- .../serviceKnowledge/mainServiceKnowledge.py | 2 +- .../subConnectorIngestConsumer.py | 11 +- .../serviceKnowledge/subFeatureBootstrap.py | 289 +++++ .../serviceKnowledge/subPolicyResolver.py | 32 - .../serviceKnowledge/subWalkerHelpers.py | 17 +- scripts/script_migrate_user_uid.py | 274 +++++ .../test_connectorDbPostgre_failLoud.py | 10 + tests/unit/services/test_buildTree.py | 359 ++++++ tests/unit/services/test_inheritFlags.py | 525 +++++++-- tests/unit/teamsbot/test_directorPrompts.py | 10 +- 27 files changed, 3728 insertions(+), 848 deletions(-) create mode 100644 modules/serviceCenter/services/serviceKnowledge/_buildTree.py create mode 100644 modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py delete mode 100644 modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py create mode 100644 scripts/script_migrate_user_uid.py create mode 100644 tests/unit/services/test_buildTree.py diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index f1a34f70..fa4cba44 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -172,7 +172,7 @@ def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: s pass # already a list elif fieldType == "BOOLEAN": - record[fieldName] = bool(value) if value is not None else False + record[fieldName] = bool(value) if value is not None else None elif fieldType == "JSONB" and value is not None: try: @@ -184,6 +184,18 @@ def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: s logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})") +def _stripNulBytesFromStr(value: Any) -> Any: + """psycopg2 rejects bound parameters whose Python str contains NUL (0x00). + + Some extracted files (e.g. SQL dumps, mixed binary treated as text) still + carry those bytes; PostgreSQL TEXT could store them via other paths, but + the client protocol path used here cannot. + """ + if isinstance(value, str) and "\x00" in value: + return value.replace("\x00", "") + return value + + def _quotePgIdent(name: str) -> str: return '"' + str(name).replace('"', '""') + '"' @@ -983,7 +995,7 @@ class DatabaseConnector: else: value = json.dumps(value) - values.append(value) + values.append(_stripNulBytesFromStr(value)) # Build INSERT/UPDATE with quoted identifiers col_names = ", ".join([f'"{col}"' for col in columns]) diff --git a/modules/datamodels/datamodelFeatureDataSource.py b/modules/datamodels/datamodelFeatureDataSource.py index f07a8bda..10fd76a7 100644 --- a/modules/datamodels/datamodelFeatureDataSource.py +++ b/modules/datamodels/datamodelFeatureDataSource.py @@ -76,6 +76,14 @@ class FeatureDataSource(PowerOnModel): ), json_schema_extra={"label": "Neutralisieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, ) + ragIndexEnabled: Optional[bool] = Field( + default=None, + description=( + "Three-state RAG-indexing flag with cascade-inherit semantics. " + "None = inherit; True/False = explicit. Cascade-reset on parent toggle." + ), + json_schema_extra={"label": "RAG-Indexierung", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, + ) neutralizeFields: Optional[List[str]] = Field( default=None, description="Column names whose values are replaced with placeholders before AI processing", diff --git a/modules/demoConfigs/investorDemo2026.py b/modules/demoConfigs/investorDemo2026.py index f8fc678f..d807921d 100644 --- a/modules/demoConfigs/investorDemo2026.py +++ b/modules/demoConfigs/investorDemo2026.py @@ -124,6 +124,7 @@ class InvestorDemo2026(_BaseDemoConfig): from modules.datamodels.datamodelUam import Mandate, UserInDB from modules.datamodels.datamodelMembership import UserMandate + summary["_removedMandateIds"] = [] for mandateDef in [_MANDATE_HAPPYLIFE, _MANDATE_ALPINA]: try: existing = db.getRecordset(Mandate, recordFilter={"name": mandateDef["name"]}) @@ -132,28 +133,36 @@ class InvestorDemo2026(_BaseDemoConfig): self._removeMandateData(db, mid, mandateDef["label"], summary) db.recordDelete(Mandate, mid) summary["removed"].append(f"Mandate {mandateDef['label']} ({mid})") + summary["_removedMandateIds"].append({"id": mid, "mandateId": mid}) logger.info(f"Removed mandate {mandateDef['label']} ({mid})") except Exception as e: summary["errors"].append(f"Remove mandate {mandateDef['label']}: {e}") + # SAFETY: NEVER delete the user record. The user may have connections, + # chats, workflows, files, and other data across multiple databases. + # Only remove the mandate memberships that THIS demo created. try: existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]}) for u in existing: uid = u.get("id") + removedMandateIds = {m.get("mandateId") for m in summary.get("_removedMandateIds", [])} memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid}) for mem in memberships: - try: - db.recordDelete(UserMandate, mem.get("id")) - except Exception: - pass - db.recordDelete(UserInDB, uid) - summary["removed"].append(f"User {_USER['username']} ({uid})") - logger.info(f"Removed user {_USER['username']} ({uid})") + if mem.get("mandateId") in removedMandateIds: + try: + db.recordDelete(UserMandate, mem.get("id")) + except Exception: + pass + summary["skipped"].append( + f"User {_USER['username']} ({uid}) preserved (only demo mandate memberships removed)" + ) + logger.info(f"Preserved user {_USER['username']} ({uid}) - removed demo mandate memberships only") except Exception as e: - summary["errors"].append(f"Remove user: {e}") + summary["errors"].append(f"Remove user memberships: {e}") self._removeLanguageSet(db, "es", summary) + summary.pop("_removedMandateIds", None) return summary # ------------------------------------------------------------------ diff --git a/modules/demoConfigs/pwgDemo2026.py b/modules/demoConfigs/pwgDemo2026.py index f0dc5e6d..4a6491a3 100644 --- a/modules/demoConfigs/pwgDemo2026.py +++ b/modules/demoConfigs/pwgDemo2026.py @@ -121,32 +121,39 @@ class PwgDemo2026(_BaseDemoConfig): from modules.datamodels.datamodelMembership import UserMandate from modules.datamodels.datamodelUam import Mandate, UserInDB + removedMandateIds = set() try: existing = db.getRecordset(Mandate, recordFilter={"name": _MANDATE_PWG["name"]}) for m in existing: mid = m.get("id") self._removeMandateData(db, mid, _MANDATE_PWG["label"], summary) db.recordDelete(Mandate, mid) + removedMandateIds.add(mid) summary["removed"].append(f"Mandate {_MANDATE_PWG['label']} ({mid})") logger.info(f"Removed mandate {_MANDATE_PWG['label']} ({mid})") except Exception as e: summary["errors"].append(f"Remove mandate {_MANDATE_PWG['label']}: {e}") + # SAFETY: NEVER delete the user record. The user may have connections, + # chats, workflows, files, and other data across multiple databases. + # Only remove the mandate memberships that THIS demo created. try: existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]}) for u in existing: uid = u.get("id") memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid}) or [] for mem in memberships: - try: - db.recordDelete(UserMandate, mem.get("id")) - except Exception: - pass - db.recordDelete(UserInDB, uid) - summary["removed"].append(f"User {_USER['username']} ({uid})") - logger.info(f"Removed user {_USER['username']} ({uid})") + if mem.get("mandateId") in removedMandateIds: + try: + db.recordDelete(UserMandate, mem.get("id")) + except Exception: + pass + summary["skipped"].append( + f"User {_USER['username']} ({uid}) preserved (only demo mandate memberships removed)" + ) + logger.info(f"Preserved user {_USER['username']} ({uid}) - removed demo mandate memberships only") except Exception as e: - summary["errors"].append(f"Remove user: {e}") + summary["errors"].append(f"Remove user memberships: {e}") return summary diff --git a/modules/features/workspace/datamodelFeatureWorkspace.py b/modules/features/workspace/datamodelFeatureWorkspace.py index 4e32702c..d0ba8815 100644 --- a/modules/features/workspace/datamodelFeatureWorkspace.py +++ b/modules/features/workspace/datamodelFeatureWorkspace.py @@ -2,7 +2,7 @@ # All rights reserved. """Workspace feature data models — WorkspaceUserSettings.""" -from typing import List, Optional +from typing import Dict, List, Optional from pydantic import Field from modules.datamodels.datamodelBase import PowerOnModel from modules.shared.i18nRegistry import i18nModel @@ -52,7 +52,7 @@ class WorkspaceUserSettings(PowerOnModel): description="Max agent rounds override (None = instance default)", json_schema_extra={"label": "Max. Agenten-Runden", "frontend_type": "number", "frontend_readonly": False, "frontend_required": False}, ) - requireNeutralization: bool = Field( + requireNeutralization: Optional[bool] = Field( default=False, description="Default neutralization setting for this user", json_schema_extra={"label": "Neutralisierung", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, @@ -67,3 +67,8 @@ class WorkspaceUserSettings(PowerOnModel): description="Allowed AI models (empty = all permitted)", json_schema_extra={"label": "Erlaubte Modelle", "frontend_type": "modelMultiSelect", "frontend_readonly": False, "frontend_required": False}, ) + uiTreeExpansion: Dict[str, List[str]] = Field( + default_factory=dict, + description="Per-tab expanded tree-node ids for the UDB / FormGeneratorTree. Key = scope name (e.g. 'sources', 'filesOwn', 'filesShared').", + json_schema_extra={"label": "Tree-Expand-Zustand", "frontend_type": "json", "frontend_readonly": True, "frontend_required": False}, + ) diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py index 2fa788e8..5c24c113 100644 --- a/modules/features/workspace/routeFeatureWorkspace.py +++ b/modules/features/workspace/routeFeatureWorkspace.py @@ -1281,52 +1281,101 @@ async def listWorkspaceDataSources( try: from modules.datamodels.datamodelDataSource import DataSource from modules.interfaces.interfaceDbApp import getRootInterface + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import buildEffectiveByConnection rootIf = getRootInterface() recordFilter: dict = {"featureInstanceId": instanceId} if wsMandateId: recordFilter["mandateId"] = wsMandateId dataSources = rootIf.db.getRecordset(DataSource, recordFilter=recordFilter) - return JSONResponse({"dataSources": dataSources or []}) + if not dataSources: + return JSONResponse({"dataSources": []}) + + # Group by connectionId and compute effective values in aggregate mode + byConnection: dict = {} + for ds in dataSources: + connId = ds.get("connectionId") or "" + byConnection.setdefault(connId, []).append(ds) + + for connDs in byConnection.values(): + effNeutralize = buildEffectiveByConnection(connDs, "neutralize", mode="aggregate") + effScope = buildEffectiveByConnection(connDs, "scope", mode="aggregate") + effRag = buildEffectiveByConnection(connDs, "ragIndexEnabled", mode="aggregate") + for ds in connDs: + dsId = ds.get("id", "") + ds["effectiveNeutralize"] = effNeutralize.get(dsId, False) + ds["effectiveScope"] = effScope.get(dsId, "personal") + ds["effectiveRagIndexEnabled"] = effRag.get(dsId, False) + + return JSONResponse({"dataSources": dataSources}) except Exception: return JSONResponse({"dataSources": []}) -@router.get("/{instanceId}/connections") +class _TreeChildrenRequest(BaseModel): + """Request body for the generic tree children endpoint.""" + parents: List[Optional[str]] = Field( + default_factory=list, + description="List of parent keys to fetch children for. Use null for top-level.", + ) + + +@router.post("/{instanceId}/tree/children") @limiter.limit("300/minute") -async def listWorkspaceConnections( +async def getTreeChildren( request: Request, instanceId: str = Path(...), + body: _TreeChildrenRequest = Body(...), context: RequestContext = Depends(getRequestContext), ): - """Return the user's active connections (UserConnections).""" - _mandateId, _ = _validateInstanceAccess(instanceId, context) - from modules.serviceCenter import getService - from modules.serviceCenter.context import ServiceCenterContext - ctx = ServiceCenterContext( - user=context.user, - mandate_id=_mandateId or "", - feature_instance_id=instanceId, + """Generic UDB tree children resolver. + + The UI sends a list of parent keys (or null for top-level). The backend + returns children for each requested parent, with all effective flag + values pre-computed. The UI builds the visible tree from the resulting + flat per-parent map. + """ + _validateInstanceAccess(instanceId, context) + from modules.serviceCenter.services.serviceKnowledge._buildTree import getChildrenForParents + + try: + nodesByParent = await getChildrenForParents(instanceId, body.parents, context) + except Exception as exc: + logger.exception("Tree children build failed: %s", exc) + raise HTTPException(status_code=500, detail=str(exc)) + return JSONResponse({"nodesByParent": nodesByParent}) + + +class _TreeAttributesRequest(BaseModel): + """Request body for the attribute-refresh endpoint.""" + keys: List[str] = Field( + default_factory=list, + description="List of node keys to fetch current attributes for.", ) - chatService = getService("chat", ctx) - connections = chatService.getUserConnections() - items = [] - for c in connections or []: - conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {}) - authority = conn.get("authority") - if hasattr(authority, "value"): - authority = authority.value - status = conn.get("status") - if hasattr(status, "value"): - status = status.value - items.append({ - "id": conn.get("id"), - "authority": authority, - "externalUsername": conn.get("externalUsername"), - "externalEmail": conn.get("externalEmail"), - "status": status, - "knowledgeIngestionEnabled": bool(conn.get("knowledgeIngestionEnabled")), - }) - return JSONResponse({"connections": items}) + + +@router.post("/{instanceId}/tree/attributes") +@limiter.limit("300/minute") +async def getTreeAttributes( + request: Request, + instanceId: str = Path(...), + body: _TreeAttributesRequest = Body(...), + context: RequestContext = Depends(getRequestContext), +): + """Return current effective attribute values (neutralize, scope, + ragIndexEnabled) for a list of node keys. Used after a toggle action + to refresh only the visible nodes without reloading tree structure.""" + _validateInstanceAccess(instanceId, context) + from modules.serviceCenter.services.serviceKnowledge._buildTree import getAttributesForKeys + + if len(body.keys) > 500: + raise HTTPException(status_code=400, detail="Max 500 keys per request") + + try: + attrs = await getAttributesForKeys(instanceId, body.keys, context) + except Exception as exc: + logger.exception("Tree attributes failed: %s", exc) + raise HTTPException(status_code=500, detail=str(exc)) + return JSONResponse({"attributes": attrs}) class CreateDataSourceRequest(BaseModel): @@ -1391,303 +1440,6 @@ async def deleteWorkspaceDataSource( # ---- Feature Connections & Feature Data Sources ---- -@router.get("/{instanceId}/feature-connections") -@limiter.limit("120/minute") -async def listFeatureConnections( - request: Request, - instanceId: str = Path(...), - context: RequestContext = Depends(getRequestContext), -): - """List feature instances the user has access to, scoped to the workspace mandate.""" - wsMandateId, _ = _validateInstanceAccess(instanceId, context) - from modules.interfaces.interfaceDbApp import getRootInterface - from modules.security.rbacCatalog import getCatalogService - from modules.datamodels.datamodelUam import Mandate - - rootIf = getRootInterface() - userId = str(context.user.id) - - catalog = getCatalogService() - featureCodesWithData = catalog.getFeaturesWithDataObjects() - - userMandates = rootIf.getUserMandates(userId) - if not userMandates: - return JSONResponse({"featureConnectionsByMandate": []}) - - allowedMandateIds = {um.mandateId for um in userMandates} - if wsMandateId and wsMandateId in allowedMandateIds: - allowedMandateIds = {wsMandateId} - - mandateLabels: dict = {} - for um in userMandates: - if um.mandateId not in allowedMandateIds: - continue - try: - rows = rootIf.db.getRecordset(Mandate, recordFilter={"id": um.mandateId}) - if rows: - m = rows[0] - mandateLabels[um.mandateId] = m.get("label") or m.get("name") or um.mandateId - except Exception: - mandateLabels[um.mandateId] = um.mandateId - - byMandate: dict = {} - seenIds: set = set() - for um in userMandates: - if um.mandateId not in allowedMandateIds: - continue - allInstances = rootIf.getFeatureInstancesByMandate(um.mandateId) - for inst in allInstances: - if inst.id in seenIds: - continue - seenIds.add(inst.id) - if not inst.enabled: - continue - if inst.featureCode not in featureCodesWithData: - continue - featureAccess = rootIf.getFeatureAccess(userId, inst.id) - if not featureAccess or not featureAccess.enabled: - continue - - featureDef = catalog.getFeatureDefinition(inst.featureCode) or {} - dataObjects = catalog.getDataObjects(inst.featureCode) - label = inst.label or inst.featureCode - mid = inst.mandateId - connItem = { - "featureInstanceId": inst.id, - "featureCode": inst.featureCode, - "mandateId": mid, - "label": label, - "icon": featureDef.get("icon", "mdi-database"), - "tableCount": len(dataObjects), - } - if mid not in byMandate: - byMandate[mid] = [] - byMandate[mid].append(connItem) - - def _sortKeyLabel(x: dict) -> str: - return (x.get("label") or "").lower() - - groups = [] - for mid in sorted(byMandate.keys(), key=lambda m: (mandateLabels.get(m, m) or "").lower()): - conns = sorted(byMandate[mid], key=_sortKeyLabel) - groups.append({ - "mandateId": mid, - "mandateLabel": mandateLabels.get(mid, mid), - "featureConnections": conns, - }) - - return JSONResponse({"featureConnectionsByMandate": groups}) - - -@router.get("/{instanceId}/feature-connections/{fiId}/tables") -@limiter.limit("120/minute") -async def listFeatureConnectionTables( - request: Request, - instanceId: str = Path(...), - fiId: str = Path(..., description="Feature instance ID"), - context: RequestContext = Depends(getRequestContext), -): - """List data tables (DATA_OBJECTS) for a feature instance, filtered by RBAC.""" - wsMandateId, _ = _validateInstanceAccess(instanceId, context) - from modules.interfaces.interfaceDbApp import getRootInterface - from modules.security.rbacCatalog import getCatalogService - - rootIf = getRootInterface() - inst = rootIf.getFeatureInstance(fiId) - if not inst: - raise HTTPException(status_code=404, detail=routeApiMsg("Feature instance not found")) - - mandateId = str(inst.mandateId) if inst.mandateId else None - if wsMandateId and mandateId and mandateId != wsMandateId: - raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate")) - catalog = getCatalogService() - - try: - from modules.security.rbac import RbacClass - from modules.security.rootAccess import getRootDbAppConnector - dbApp = getRootDbAppConnector() - rbac = RbacClass(dbApp, dbApp=dbApp) - accessible = catalog.getAccessibleDataObjects( - featureCode=inst.featureCode, - rbacInstance=rbac, - user=context.user, - mandateId=mandateId or "", - featureInstanceId=fiId, - ) - except Exception: - accessible = catalog.getDataObjects(inst.featureCode) - - accessibleKeys = {obj.get("objectKey", "") for obj in accessible} - referencedGroups = set() - for obj in accessible: - meta = obj.get("meta", {}) - if meta.get("wildcard") or meta.get("isGroup"): - continue - if meta.get("group"): - referencedGroups.add(meta["group"]) - - tables = [] - for obj in catalog.getDataObjects(inst.featureCode): - meta = obj.get("meta", {}) - if meta.get("wildcard"): - continue - objectKey = obj.get("objectKey", "") - if meta.get("isGroup"): - # Groups are metadata-only; include if at least one child is accessible - # (regardless of whether the group itself was RBAC-granted). - if objectKey not in referencedGroups: - continue - else: - if objectKey not in accessibleKeys: - continue - node = { - "objectKey": objectKey, - "tableName": meta.get("table", ""), - "label": resolveText(obj.get("label", "")), - "fields": meta.get("fields", []), - "isParent": bool(meta.get("isParent", False)), - "parentTable": meta.get("parentTable") or None, - "parentKey": meta.get("parentKey") or None, - "displayFields": meta.get("displayFields", []), - "isGroup": bool(meta.get("isGroup", False)), - "group": meta.get("group") or None, - } - tables.append(node) - - return JSONResponse({"tables": tables}) - - -@router.get("/{instanceId}/feature-connections/{fiId}/parent-objects/{tableName}") -@limiter.limit("120/minute") -async def listParentObjects( - request: Request, - instanceId: str = Path(...), - fiId: str = Path(..., description="Feature instance ID"), - tableName: str = Path(..., description="Parent table name from DATA_OBJECTS"), - parentKey: Optional[str] = Query(None, description="Optional FK column name to filter by ancestor record (nested parent rendering)"), - parentValue: Optional[str] = Query(None, description="Optional FK value matching parentKey to filter children of a specific ancestor record"), - context: RequestContext = Depends(getRequestContext), -): - """List records from a parent table so the user can pick a specific record to scope data. - - When parentKey + parentValue are provided, results are additionally filtered by that FK, - enabling nested record hierarchies (e.g. Sessions OF Context X). - """ - wsMandateId, _ = _validateInstanceAccess(instanceId, context) - from modules.interfaces.interfaceDbApp import getRootInterface - from modules.security.rbacCatalog import getCatalogService - - rootIf = getRootInterface() - inst = rootIf.getFeatureInstance(fiId) - if not inst: - raise HTTPException(status_code=404, detail=routeApiMsg("Feature instance not found")) - - featureCode = inst.featureCode - mandateId = str(inst.mandateId) if inst.mandateId else "" - if wsMandateId and mandateId and mandateId != wsMandateId: - raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate")) - catalog = getCatalogService() - - parentObj = None - for obj in catalog.getDataObjects(featureCode): - meta = obj.get("meta", {}) - if meta.get("table") == tableName and meta.get("isParent"): - parentObj = obj - break - if not parentObj: - raise HTTPException(status_code=400, detail=f"Table '{tableName}' is not a registered parent table") - - displayFields = parentObj["meta"].get("displayFields", []) - selectCols = ', '.join(f'"{f}"' for f in (["id"] + displayFields)) if displayFields else "*" - - from modules.connectors.connectorDbPostgre import DatabaseConnector - from modules.shared.configuration import APP_CONFIG - featureDbName = f"poweron_{featureCode.lower()}" - featureDbConn = None - try: - featureDbConn = DatabaseConnector( - dbHost=APP_CONFIG.get("DB_HOST", "localhost"), - dbDatabase=featureDbName, - dbUser=APP_CONFIG.get("DB_USER"), - dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), - dbPort=int(APP_CONFIG.get("DB_PORT", 5432)), - userId=str(context.user.id), - ) - conn = featureDbConn.connection - with conn.cursor() as cur: - cur.execute( - "SELECT column_name FROM information_schema.columns " - "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " - "AND column_name IN ('featureInstanceId', 'instanceId')", - [tableName], - ) - instanceCols = [row["column_name"] for row in cur.fetchall()] - instanceCol = "featureInstanceId" if "featureInstanceId" in instanceCols else "instanceId" - - cur.execute( - "SELECT column_name FROM information_schema.columns " - "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " - "AND column_name = 'userId'", - [tableName], - ) - hasUserId = cur.rowcount > 0 - - sql = ( - f'SELECT {selectCols} FROM "{tableName}" ' - f'WHERE "{instanceCol}" = %s' - ) - params = [fiId] - if mandateId: - sql += ' AND "mandateId" = %s' - params.append(mandateId) - if hasUserId: - sql += ' AND "userId" = %s' - params.append(str(context.user.id)) - - if parentKey and parentValue: - cur.execute( - "SELECT 1 FROM information_schema.columns " - "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " - "AND column_name = %s", - [tableName, parentKey], - ) - if cur.rowcount > 0: - sql += f' AND "{parentKey}" = %s' - params.append(parentValue) - else: - logger.warning( - f"listParentObjects({tableName}): ignoring parentKey '{parentKey}' (column does not exist)" - ) - - sql += ' ORDER BY "id" DESC LIMIT 100' - cur.execute(sql, params) - rows = [] - for row in cur.fetchall(): - r = dict(row) - for k, v in r.items(): - if hasattr(v, "isoformat"): - r[k] = v.isoformat() - elif isinstance(v, (bytes, bytearray)): - r[k] = f"" - displayParts = [str(r.get(f, "")) for f in displayFields if r.get(f) is not None] - rows.append({ - "id": r.get("id", ""), - "displayLabel": " | ".join(displayParts) if displayParts else r.get("id", ""), - "fields": {f: r.get(f) for f in displayFields}, - }) - except Exception as e: - logger.error(f"listParentObjects({tableName}) failed: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Failed to list parent objects: {e}") - finally: - if featureDbConn: - try: - featureDbConn.close() - except Exception: - pass - - return JSONResponse({"parentObjects": rows}) - - class CreateFeatureDataSourceRequest(BaseModel): """Request body for adding a feature table as data source.""" featureInstanceId: str = Field(description="Feature instance ID") @@ -1706,16 +1458,35 @@ async def createFeatureDataSource( body: CreateFeatureDataSourceRequest = Body(...), context: RequestContext = Depends(getRequestContext), ): - """Create a FeatureDataSource for this workspace instance.""" + """Create a FeatureDataSource for this workspace instance. + + The FDS lives under the WORKSPACE's mandate (not the feature's): that + matches how the tree (`allFds = recordset where workspaceInstanceId = + instanceId`) and the PATCH endpoints scope these records — by workspace, + not by feature mandate. The user can legitimately reference a feature + from another mandate they have access to (via the UDB mandate-group + nodes), and a hard cross-mandate block here would silently 403 those + toggles. Access to the referenced feature is verified by the user's + `FeatureAccess` and the existing tree-children RBAC, which run before + the user can ever click on this node. + """ wsMandateId, _ = _validateInstanceAccess(instanceId, context) from modules.interfaces.interfaceDbApp import getRootInterface from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource rootIf = getRootInterface() - inst = rootIf.getFeatureInstance(body.featureInstanceId) - mandateId = str(inst.mandateId) if inst else (str(context.mandateId) if context.mandateId else "") - if wsMandateId and mandateId and mandateId != wsMandateId: - raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate")) + if not rootIf.getFeatureAccess(str(context.user.id), body.featureInstanceId): + raise HTTPException(status_code=403, detail=routeApiMsg("Access denied to this feature instance")) + + existing = rootIf.db.getRecordset(FeatureDataSource, recordFilter={ + "workspaceInstanceId": instanceId, + "featureInstanceId": body.featureInstanceId, + "tableName": body.tableName, + }) or [] + targetFilter = body.recordFilter or None + for rec in existing: + if (rec.get("recordFilter") or None) == targetFilter: + return JSONResponse(rec) fds = FeatureDataSource( featureInstanceId=body.featureInstanceId, @@ -1723,7 +1494,7 @@ async def createFeatureDataSource( tableName=body.tableName, objectKey=body.objectKey, label=body.label, - mandateId=mandateId, + mandateId=wsMandateId or "", userId=str(context.user.id), workspaceInstanceId=instanceId, recordFilter=body.recordFilter, @@ -1743,13 +1514,26 @@ async def listFeatureDataSources( wsMandateId, _ = _validateInstanceAccess(instanceId, context) from modules.interfaces.interfaceDbApp import getRootInterface from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import buildEffectiveByWorkspaceFds rootIf = getRootInterface() recordFilter: dict = {"workspaceInstanceId": instanceId} if wsMandateId: recordFilter["mandateId"] = wsMandateId records = rootIf.db.getRecordset(FeatureDataSource, recordFilter=recordFilter) - return JSONResponse({"featureDataSources": records or []}) + if not records: + return JSONResponse({"featureDataSources": []}) + + effNeutralize = buildEffectiveByWorkspaceFds(records, "neutralize", mode="aggregate") + effScope = buildEffectiveByWorkspaceFds(records, "scope", mode="aggregate") + effRag = buildEffectiveByWorkspaceFds(records, "ragIndexEnabled", mode="aggregate") + for fds in records: + fdsId = fds.get("id", "") + fds["effectiveNeutralize"] = effNeutralize.get(fdsId, False) + fds["effectiveScope"] = effScope.get(fdsId, "personal") + fds["effectiveRagIndexEnabled"] = effRag.get(fdsId, False) + + return JSONResponse({"featureDataSources": records}) @router.delete("/{instanceId}/feature-datasources/{featureDataSourceId}") @@ -1770,112 +1554,6 @@ async def deleteFeatureDataSource( return JSONResponse({"success": True}) -@router.get("/{instanceId}/connections/{connectionId}/services") -@limiter.limit("120/minute") -async def listConnectionServices( - request: Request, - instanceId: str = Path(...), - connectionId: str = Path(...), - context: RequestContext = Depends(getRequestContext), -): - """Return the available services for a specific UserConnection.""" - _mandateId, _ = _validateInstanceAccess(instanceId, context) - try: - from modules.connectors.connectorResolver import ConnectorResolver - from modules.serviceCenter import getService as getSvc - from modules.serviceCenter.context import ServiceCenterContext - ctx = ServiceCenterContext( - user=context.user, - mandate_id=_mandateId or "", - feature_instance_id=instanceId, - ) - chatService = getSvc("chat", ctx) - securityService = getSvc("security", ctx) - dbInterface = _buildResolverDbInterface(chatService) - resolver = ConnectorResolver(securityService, dbInterface) - provider = await resolver.resolve(connectionId) - services = provider.getAvailableServices() - _serviceLabels = { - "sharepoint": "SharePoint", - "outlook": "Outlook", - "teams": "Teams", - "onedrive": "OneDrive", - "drive": "Google Drive", - "gmail": "Gmail", - "files": "Files (FTP)", - "kdrive": "kDrive", - "calendar": "Calendar", - "contact": "Contacts", - } - _serviceIcons = { - "sharepoint": "sharepoint", - "outlook": "mail", - "teams": "chat", - "onedrive": "cloud", - "drive": "cloud", - "gmail": "mail", - "files": "folder", - "kdrive": "cloud", - "calendar": "calendar", - "contact": "contact", - } - items = [ - { - "service": s, - "label": _serviceLabels.get(s, s), - "icon": _serviceIcons.get(s, "folder"), - } - for s in services - ] - return JSONResponse({"services": items}) - except Exception as e: - logger.error(f"Error listing services for connection {connectionId}: {e}") - return JSONResponse({"services": [], "error": str(e)}, status_code=400) - - -@router.get("/{instanceId}/connections/{connectionId}/browse") -@limiter.limit("300/minute") -async def browseConnectionService( - request: Request, - instanceId: str = Path(...), - connectionId: str = Path(...), - service: str = Query(..., description="Service name (e.g. sharepoint, onedrive, outlook)"), - path: str = Query("/", description="Path within the service to browse"), - context: RequestContext = Depends(getRequestContext), -): - """Browse folders/items within a connection's service at a given path.""" - _mandateId, _ = _validateInstanceAccess(instanceId, context) - try: - from modules.connectors.connectorResolver import ConnectorResolver - from modules.serviceCenter import getService as getSvc - from modules.serviceCenter.context import ServiceCenterContext - ctx = ServiceCenterContext( - user=context.user, - mandate_id=_mandateId or "", - feature_instance_id=instanceId, - ) - chatService = getSvc("chat", ctx) - securityService = getSvc("security", ctx) - dbInterface = _buildResolverDbInterface(chatService) - resolver = ConnectorResolver(securityService, dbInterface) - adapter = await resolver.resolveService(connectionId, service) - entries = await adapter.browse(path, filter=None) - items = [] - for entry in (entries or []): - items.append({ - "name": entry.name, - "path": entry.path, - "isFolder": entry.isFolder, - "size": entry.size, - "mimeType": entry.mimeType, - "metadata": entry.metadata if hasattr(entry, "metadata") else {}, - }) - return JSONResponse({"items": items, "path": path, "service": service}) - except Exception as e: - logger.error(f"Error browsing {service} for connection {connectionId} at '{path}': {e}") - return JSONResponse({"items": [], "error": str(e)}, status_code=400) - - # --------------------------------------------------------------------------- # Voice endpoints # --------------------------------------------------------------------------- @@ -2191,6 +1869,71 @@ async def putWorkspaceUserSettings( }) +# ========================================================================= +# Per-user UI state: tree expand/collapse (UDB + FilesTab) +# Persisted on WorkspaceUserSettings.uiTreeExpansion as a {scope: [ids]} map. +# Each FE tab uses its own scope key so collapse-state for one tab doesn't +# bleed into another. + +@router.get("/{instanceId}/ui-tree-expansion/{scope}") +@limiter.limit("300/minute") +async def getUiTreeExpansion( + request: Request, + instanceId: str = Path(...), + scope: str = Path(..., description="UI scope key, e.g. 'sources', 'filesOwn', 'filesShared'"), + context: RequestContext = Depends(getRequestContext), +): + """Return the expanded tree-node ids for the current user + scope. + + Returns `null` when the user has never persisted a state for this scope + (lets the FE fall back to backend `defaultExpanded` hints). Returns `[]` + when the user actively collapsed everything. + """ + _validateInstanceAccess(instanceId, context) + wsInterface = _getWorkspaceInterface(context, instanceId) + settings = wsInterface.getWorkspaceUserSettings(str(context.user.id)) + expansion = (settings.uiTreeExpansion if settings else {}) or {} + if scope not in expansion: + return JSONResponse({"expandedNodes": None}) + return JSONResponse({"expandedNodes": list(expansion.get(scope) or [])}) + + +@router.put("/{instanceId}/ui-tree-expansion/{scope}") +@limiter.limit("300/minute") +async def putUiTreeExpansion( + request: Request, + instanceId: str = Path(...), + scope: str = Path(...), + body: dict = Body(...), + context: RequestContext = Depends(getRequestContext), +): + """Replace the expanded-node list for one scope. + + Body: `{"expandedNodes": List[str]}`. Empty list = explicit collapse-all. + """ + _validateInstanceAccess(instanceId, context) + wsInterface = _getWorkspaceInterface(context, instanceId) + userId = str(context.user.id) + nodes = body.get("expandedNodes") + if not isinstance(nodes, list): + raise HTTPException(status_code=400, detail=routeApiMsg("expandedNodes must be a list")) + cleaned = [str(n) for n in nodes if isinstance(n, (str, int))] + + existing = wsInterface.getWorkspaceUserSettings(userId) + existingMap: Dict[str, List[str]] = (existing.uiTreeExpansion if existing else {}) or {} + existingMap = dict(existingMap) + existingMap[scope] = cleaned + + data = { + "userId": userId, + "mandateId": str(context.mandateId) if context.mandateId else "", + "featureInstanceId": instanceId, + "uiTreeExpansion": existingMap, + } + wsInterface.saveWorkspaceUserSettings(data) + return JSONResponse({"expandedNodes": cleaned}) + + # ========================================================================= # RAG / Knowledge — anonymised instance statistics (presentation / KPIs) diff --git a/modules/routes/routeAdminDemoConfig.py b/modules/routes/routeAdminDemoConfig.py index db37e775..0673c299 100644 --- a/modules/routes/routeAdminDemoConfig.py +++ b/modules/routes/routeAdminDemoConfig.py @@ -68,9 +68,19 @@ def removeDemoConfig( request: Request, currentUser: User = Depends(requirePlatformAdmin), ) -> dict: - """Remove all data created by a demo configuration.""" + """Remove all data created by a demo configuration. + + Requires X-Confirm-Destructive: true header as safety guard. + """ from modules.demoConfigs import getDemoConfigByCode + confirmHeader = request.headers.get("X-Confirm-Destructive", "").lower() + if confirmHeader != "true": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Destructive operation requires header X-Confirm-Destructive: true", + ) + config = getDemoConfigByCode(code) if not config: raise HTTPException( @@ -79,7 +89,7 @@ def removeDemoConfig( ) db = getRootDbAppConnector() - logger.info(f"Removing demo config '{code}' (user: {currentUser.username})") + logger.info(f"Removing demo config '{code}' (user: {currentUser.username}, confirmed)") summary = config.remove(db) logger.info(f"Demo config '{code}' removed: {summary}") diff --git a/modules/routes/routeDataConnections.py b/modules/routes/routeDataConnections.py index e2b08461..2bc48042 100644 --- a/modules/routes/routeDataConnections.py +++ b/modules/routes/routeDataConnections.py @@ -778,7 +778,12 @@ async def _updateKnowledgeConsent( cancelled = cancelJobsByConnection(connectionId) else: from modules.datamodels.datamodelDataSource import DataSource - dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId, "ragIndexEnabled": True}) + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag + allConnDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + dataSources = [ + ds for ds in (allConnDs or []) + if getEffectiveFlag(ds, "ragIndexEnabled", allConnDs, mode="walk") is True + ] if dataSources: from modules.serviceCenter.services.serviceBackgroundJobs import startJob authority = connection.authority.value if hasattr(connection.authority, "value") else str(connection.authority or "") diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py index b22dacae..4bcbcf8f 100644 --- a/modules/routes/routeDataFiles.py +++ b/modules/routes/routeDataFiles.py @@ -211,7 +211,7 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user, *, man from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob - await knowledgeService.requestIngestion( + handle = await knowledgeService.requestIngestion( IngestionJob( sourceKind="file", sourceId=fileId, @@ -229,7 +229,10 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user, *, man # Re-acquire interface after await to avoid stale user context from the singleton mgmtInterface = interfaceDbManagement.getInterface(user) mgmtInterface.updateFile(fileId, {"status": "active"}) - logger.info(f"Auto-index complete for file {fileId} ({fileName})") + if handle.status == "failed": + logger.warning(f"Auto-index ingestion failed for file {fileId} ({fileName}): {handle.error}") + else: + logger.info(f"Auto-index complete for file {fileId} ({fileName})") except Exception as e: logger.error(f"Auto-index failed for file {fileId}: {e}", exc_info=True) @@ -256,6 +259,24 @@ router = APIRouter( ) +def _getInterfaceForOwnedItem(currentUser: User, context, itemId: str, modelClass) -> Any: + """Create a management interface scoped to the item's own context. + Looks up the item by ID (unscoped) to resolve its mandateId/featureInstanceId, + then creates the interface with THAT context. This ensures toggle operations + work regardless of which page the user is on.""" + unscoped = interfaceDbManagement.getInterface(currentUser) + record = unscoped.db.getRecord(modelClass, itemId) + if not record: + raise interfaceDbManagement.FileNotFoundError(f"Item {itemId} not found") + itemMandateId = record.get("mandateId") if isinstance(record, dict) else getattr(record, "mandateId", None) + itemInstanceId = record.get("featureInstanceId") if isinstance(record, dict) else getattr(record, "featureInstanceId", None) + return interfaceDbManagement.getInterface( + currentUser, + mandateId=str(itemMandateId) if itemMandateId else None, + featureInstanceId=str(itemInstanceId) if itemInstanceId else None, + ) + + @router.get("/folders/tree") @limiter.limit("120/minute") def get_folder_tree( @@ -272,10 +293,12 @@ def get_folder_tree( ) o = (owner or "me").strip().lower() if o == "me": - return managementInterface.getOwnFolderTree() - if o == "shared": - return managementInterface.getSharedFolderTree() - raise HTTPException(status_code=400, detail="owner must be 'me' or 'shared'") + folders = managementInterface.getOwnFolderTree() + elif o == "shared": + folders = managementInterface.getSharedFolderTree() + else: + raise HTTPException(status_code=400, detail="owner must be 'me' or 'shared'") + return folders except HTTPException: raise except Exception as e: @@ -283,6 +306,185 @@ def get_folder_tree( raise HTTPException(status_code=500, detail=str(e)) +@router.post("/attributes") +@limiter.limit("120/minute") +def getAttributesForIds( + request: Request, + body: Dict[str, Any] = Body(...), + currentUser: User = Depends(getCurrentUser), + context: RequestContext = Depends(getRequestContext), +): + """Return current attribute values (neutralize, scope, ragIndexEnabled) for + a list of node IDs. For folder IDs, computes 'mixed' by checking direct + children. The frontend sends this after every toggle to refresh visible + nodes without reloading the tree structure.""" + ids = body.get("ids", []) + if not isinstance(ids, list) or len(ids) == 0: + return {} + if len(ids) > 500: + raise HTTPException(status_code=400, detail="Max 500 IDs per request") + + try: + managementInterface = interfaceDbManagement.getInterface( + currentUser, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, + ) + db = managementInterface.db + userId = str(currentUser.id) + + allFolders = db.getRecordset(FileFolder, recordFilter={"sysCreatedBy": userId}) or [] + allFiles = db.getRecordset(FileItem, recordFilter={"sysCreatedBy": userId}) or [] + + folderById = {f["id"]: f for f in allFolders} + fileById = {f["id"]: f for f in allFiles} + + logger.info( + "getAttributesForIds: %d ids requested, %d folders found, %d files found", + len(ids), len(allFolders), len(allFiles), + ) + + result: Dict[str, Dict[str, Any]] = {} + + for nodeId in ids: + if nodeId.startswith("__filesRoot:"): + attrs = _computeSyntheticRootAttrs(allFolders, allFiles) + result[nodeId] = attrs + elif nodeId in folderById: + folder = folderById[nodeId] + attrs = _computeFolderAttrs(folder, allFolders, allFiles) + result[nodeId] = attrs + elif nodeId in fileById: + f = fileById[nodeId] + result[nodeId] = { + "neutralize": bool(f.get("neutralize", False)), + "scope": f.get("scope", "personal"), + } + else: + logger.debug("getAttributesForIds: unknown id=%s", nodeId) + + logger.info("getAttributesForIds: returning %d entries", len(result)) + return result + except HTTPException: + raise + except Exception as e: + logger.error(f"getAttributesForIds error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +def _computeFolderAttrs( + folder: Dict[str, Any], + allFolders: List[Dict[str, Any]], + allFiles: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Compute attributes for a folder. Recursively checks the entire subtree: + if ANY descendant at any depth has a different value, the folder shows 'mixed'. + This propagates up through all ancestor levels.""" + fid = folder["id"] + neutralizeResult = _effectiveNeutralize(fid, allFolders, allFiles) + scopeResult = _effectiveScope(fid, allFolders, allFiles) + return {"neutralize": neutralizeResult, "scope": scopeResult} + + +def _effectiveNeutralize( + folderId: str, + allFolders: List[Dict[str, Any]], + allFiles: List[Dict[str, Any]], +) -> Any: + """Recursively compute effective neutralize for a folder. + Returns 'mixed' if any descendants diverge, otherwise the folder's own value.""" + childFolders = [f for f in allFolders if f.get("parentId") == folderId] + childFiles = [f for f in allFiles if f.get("folderId") == folderId] + + if not childFolders and not childFiles: + folder = next((f for f in allFolders if f["id"] == folderId), None) + return bool(folder.get("neutralize", False)) if folder else False + + childVals = set() + for cf in childFolders: + effective = _effectiveNeutralize(cf["id"], allFolders, allFiles) + if effective == "mixed": + return "mixed" + childVals.add(effective) + for cf in childFiles: + childVals.add(bool(cf.get("neutralize", False))) + + if len(childVals) > 1: + return "mixed" + if not childVals: + folder = next((f for f in allFolders if f["id"] == folderId), None) + return bool(folder.get("neutralize", False)) if folder else False + return childVals.pop() + + +def _effectiveScope( + folderId: str, + allFolders: List[Dict[str, Any]], + allFiles: List[Dict[str, Any]], +) -> Any: + """Recursively compute effective scope for a folder. + Returns 'mixed' if any descendants diverge, otherwise the folder's own value.""" + childFolders = [f for f in allFolders if f.get("parentId") == folderId] + childFiles = [f for f in allFiles if f.get("folderId") == folderId] + + if not childFolders and not childFiles: + folder = next((f for f in allFolders if f["id"] == folderId), None) + return folder.get("scope", "personal") if folder else "personal" + + childVals = set() + for cf in childFolders: + effective = _effectiveScope(cf["id"], allFolders, allFiles) + if effective == "mixed": + return "mixed" + childVals.add(effective) + for cf in childFiles: + childVals.add(cf.get("scope", "personal")) + + if len(childVals) > 1: + return "mixed" + if not childVals: + folder = next((f for f in allFolders if f["id"] == folderId), None) + return folder.get("scope", "personal") if folder else "personal" + return childVals.pop() + + +def _computeSyntheticRootAttrs( + allFolders: List[Dict[str, Any]], + allFiles: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Compute attributes for the synthetic root by recursively checking the + entire tree. If ANY item at any depth diverges, root shows 'mixed'.""" + topFolders = [f for f in allFolders if not f.get("parentId")] + topFiles = [f for f in allFiles if not f.get("folderId")] + + neutralizeVals = set() + scopeVals = set() + for cf in topFolders: + nEff = _effectiveNeutralize(cf["id"], allFolders, allFiles) + if nEff == "mixed": + neutralizeVals.add(True) + neutralizeVals.add(False) + else: + neutralizeVals.add(nEff) + sEff = _effectiveScope(cf["id"], allFolders, allFiles) + if sEff == "mixed": + scopeVals.add("__mixed_a__") + scopeVals.add("__mixed_b__") + else: + scopeVals.add(sEff) + for cf in topFiles: + neutralizeVals.add(bool(cf.get("neutralize", False))) + scopeVals.add(cf.get("scope", "personal")) + + if not neutralizeVals and not scopeVals: + return {"neutralize": False, "scope": "personal"} + + return { + "neutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False), + "scope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"), + } + + @router.post("/folders", status_code=status.HTTP_201_CREATED) @limiter.limit("30/minute") def create_folder( @@ -353,7 +555,12 @@ def move_folder( context: RequestContext = Depends(getRequestContext), ): try: + # FE may send `parentId` or `targetParentId`. Accept both so the + # FormGeneratorTree generic `provider.moveNodes(targetParentId)` API + # remains consistent with the file-move (PUT /api/files/{id}) shape. newParentId = body.get("parentId") + if newParentId is None: + newParentId = body.get("targetParentId") managementInterface = interfaceDbManagement.getInterface( currentUser, mandateId=str(context.mandateId) if context.mandateId else None, @@ -414,11 +621,7 @@ def patch_folder_scope( if not scope: raise HTTPException(status_code=400, detail="scope is required") cascadeToFiles = body.get("cascadeChildren", body.get("cascadeToFiles", False)) - managementInterface = interfaceDbManagement.getInterface( - currentUser, - mandateId=str(context.mandateId) if context.mandateId else None, - featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, - ) + managementInterface = _getInterfaceForOwnedItem(currentUser, context, folderId, FileFolder) return managementInterface.patchFolderScope(folderId, scope, cascadeToFiles) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -446,11 +649,7 @@ def patch_folder_neutralize( neutralize = body.get("neutralize") if neutralize is None: raise HTTPException(status_code=400, detail="neutralize is required") - managementInterface = interfaceDbManagement.getInterface( - currentUser, - mandateId=str(context.mandateId) if context.mandateId else None, - featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, - ) + managementInterface = _getInterfaceForOwnedItem(currentUser, context, folderId, FileFolder) return managementInterface.patchFolderNeutralize(folderId, bool(neutralize)) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @@ -1031,11 +1230,7 @@ def updateFileScope( if scope == "global" and not context.isSysAdmin: raise HTTPException(status_code=403, detail=routeApiMsg("Only sysadmins can set global scope")) - managementInterface = interfaceDbManagement.getInterface( - context.user, - mandateId=str(context.mandateId) if context.mandateId else None, - featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, - ) + managementInterface = _getInterfaceForOwnedItem(context.user, context, fileId, FileItem) managementInterface.updateFile(fileId, {"scope": scope}) @@ -1093,11 +1288,7 @@ def updateFileNeutralize( fails the file simply has no index — no un-neutralized data can leak. """ try: - managementInterface = interfaceDbManagement.getInterface( - context.user, - mandateId=str(context.mandateId) if context.mandateId else None, - featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, - ) + managementInterface = _getInterfaceForOwnedItem(context.user, context, fileId, FileItem) managementInterface.updateFile(fileId, {"neutralize": neutralize}) @@ -1212,7 +1403,8 @@ def update_file( request: Request, fileId: str = Path(..., description="ID of the file to update"), file_info: Dict[str, Any] = Body(...), - currentUser: User = Depends(getCurrentUser) + currentUser: User = Depends(getCurrentUser), + context: RequestContext = Depends(getRequestContext), ) -> FileItem: """Update file info""" try: @@ -1221,7 +1413,11 @@ def update_file( if not safeData: raise HTTPException(status_code=400, detail=routeApiMsg("No editable fields provided")) - managementInterface = interfaceDbManagement.getInterface(currentUser) + managementInterface = interfaceDbManagement.getInterface( + currentUser, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, + ) file = managementInterface.getFile(fileId) if not file: @@ -1267,10 +1463,15 @@ def update_file( def delete_file( request: Request, fileId: str = Path(..., description="ID of the file to delete"), - currentUser: User = Depends(getCurrentUser) + currentUser: User = Depends(getCurrentUser), + context: RequestContext = Depends(getRequestContext), ) -> Dict[str, Any]: """Delete a file""" - managementInterface = interfaceDbManagement.getInterface(currentUser) + managementInterface = interfaceDbManagement.getInterface( + currentUser, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, + ) # Check if the file exists existingFile = managementInterface.getFile(fileId) diff --git a/modules/routes/routeDataSources.py b/modules/routes/routeDataSources.py index 5dec19c8..b2f919b7 100644 --- a/modules/routes/routeDataSources.py +++ b/modules/routes/routeDataSources.py @@ -43,6 +43,49 @@ def _ensureConnectionKnowledgeFlag(rootIf, connectionId: str) -> None: except Exception as e: logger.warning("Could not auto-enable knowledgeIngestionEnabled for connection %s: %s", connectionId, e) +def _computeOwnEffective(rootIf, rec, model, sourceId: str, flag: str) -> Any: + """Re-load the record after modification and compute its aggregate effective value.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + getEffectiveFlag, getEffectiveFlagFds, + ) + freshRec = rootIf.db.getRecord(model, sourceId) + if not freshRec: + return None + if model is DataSource: + connectionId = freshRec.get("connectionId", "") + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + return getEffectiveFlag(freshRec, flag, allDs, mode="aggregate") + else: + wsId = freshRec.get("workspaceInstanceId", "") + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": wsId}) + return getEffectiveFlagFds(freshRec, flag, allFds, mode="aggregate") + + +def _computeAncestorEffectives(rootIf, rec, model, flag: str) -> List[Dict[str, Any]]: + """Compute the aggregate effective value for all ancestors of `rec`.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + collectAncestorChain, collectAncestorChainFds, + getEffectiveFlag, getEffectiveFlagFds, + ) + effectiveKey = f"effective{flag[0].upper()}{flag[1:]}" + if model is DataSource: + connectionId = rec.get("connectionId", "") + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + ancestors = collectAncestorChain(rec, allDs) + return [ + {"id": a.get("id") or getattr(a, "id", ""), effectiveKey: getEffectiveFlag(a, flag, allDs, mode="aggregate")} + for a in ancestors + ] + else: + wsId = rec.get("workspaceInstanceId", "") + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": wsId}) + ancestors = collectAncestorChainFds(rec, allFds) + return [ + {"id": a.get("id") or getattr(a, "id", ""), effectiveKey: getEffectiveFlagFds(a, flag, allFds, mode="aggregate")} + for a in ancestors + ] + + router = APIRouter( prefix="/api/datasources", tags=["Data Sources"], @@ -91,26 +134,41 @@ def _updateDataSourceScope( try: from modules.interfaces.interfaceDbApp import getRootInterface from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( - cascadeResetDescendants, - cascadeResetDescendantsFds, + cascadeResetDescendants, cascadeResetDescendantsFds, + getEffectiveFlag, getEffectiveFlagFds, + collectAncestorChain, collectAncestorChainFds, ) rootIf = getRootInterface() rec, model = _findSourceRecord(rootIf.db, sourceId) if not rec: raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") - rootIf.db.recordModify(model, sourceId, {"scope": scope}) - cascaded = 0 + # 1. Cascade reset descendants bottom-up (before modifying master) + resetIds: List[str] = [] if scope is not None: if model is DataSource: - cascaded = cascadeResetDescendants(rootIf, rec, "scope") + resetIds = cascadeResetDescendants(rootIf, rec, "scope") else: - cascaded = cascadeResetDescendantsFds(rootIf, rec, "scope") + resetIds = cascadeResetDescendantsFds(rootIf, rec, "scope") + + # 2. Set master value last (crash-safe) + rootIf.db.recordModify(model, sourceId, {"scope": scope}) + + # 3. Compute effective + ancestor chain for response + updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "scope") + effectiveScope = _computeOwnEffective(rootIf, rec, model, sourceId, "scope") + logger.info( "Updated scope=%s for %s %s (cascade-reset %d descendants)", - scope, model.__name__, sourceId, cascaded, + scope, model.__name__, sourceId, len(resetIds), ) - return {"sourceId": sourceId, "scope": scope, "updated": True, "cascadedDescendants": cascaded} + return { + "sourceId": sourceId, + "scope": scope, + "effectiveScope": effectiveScope, + "resetDescendantIds": resetIds, + "updatedAncestors": updatedAncestors, + } except HTTPException: raise except Exception as e: @@ -133,26 +191,39 @@ def _updateDataSourceNeutralize( try: from modules.interfaces.interfaceDbApp import getRootInterface from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( - cascadeResetDescendants, - cascadeResetDescendantsFds, + cascadeResetDescendants, cascadeResetDescendantsFds, ) rootIf = getRootInterface() rec, model = _findSourceRecord(rootIf.db, sourceId) if not rec: raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") - rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize}) - cascaded = 0 + # 1. Cascade reset descendants bottom-up (before modifying master) + resetIds: List[str] = [] if neutralize is not None: if model is DataSource: - cascaded = cascadeResetDescendants(rootIf, rec, "neutralize") + resetIds = cascadeResetDescendants(rootIf, rec, "neutralize") else: - cascaded = cascadeResetDescendantsFds(rootIf, rec, "neutralize") + resetIds = cascadeResetDescendantsFds(rootIf, rec, "neutralize") + + # 2. Set master value last (crash-safe) + rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize}) + + # 3. Compute effective + ancestor chain for response + updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "neutralize") + effectiveNeutralize = _computeOwnEffective(rootIf, rec, model, sourceId, "neutralize") + logger.info( "Updated neutralize=%s for %s %s (cascade-reset %d descendants)", - neutralize, model.__name__, sourceId, cascaded, + neutralize, model.__name__, sourceId, len(resetIds), ) - return {"sourceId": sourceId, "neutralize": neutralize, "updated": True, "cascadedDescendants": cascaded} + return { + "sourceId": sourceId, + "neutralize": neutralize, + "effectiveNeutralize": effectiveNeutralize, + "resetDescendantIds": resetIds, + "updatedAncestors": updatedAncestors, + } except HTTPException: raise except Exception as e: @@ -204,46 +275,57 @@ async def _updateDataSourceRagIndex( `True` enqueues a mini-bootstrap. `False` synchronously purges chunks. Must be `async def` so `await startJob(...)` registers `_runJob` in the - main event loop. Sync route → worker thread → temporary loop closes - before the task runs → job stays stuck forever. + main event loop. """ try: from modules.interfaces.interfaceDbApp import getRootInterface - from modules.serviceCenter.services.serviceKnowledge._inheritFlags import cascadeResetDescendants + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + cascadeResetDescendants, cascadeResetDescendantsFds, + ) rootIf = getRootInterface() - rec = rootIf.db.getRecord(DataSource, sourceId) + rec, model = _findSourceRecord(rootIf.db, sourceId) if not rec: raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") - rootIf.db.recordModify(DataSource, sourceId, {"ragIndexEnabled": ragIndexEnabled}) - cascaded = 0 + # 1. Cascade reset descendants bottom-up (before modifying master) + resetIds: List[str] = [] if ragIndexEnabled is not None: - cascaded = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled") + if model is DataSource: + resetIds = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled") + else: + resetIds = cascadeResetDescendantsFds(rootIf, rec, "ragIndexEnabled") + + # 2. Set master value last (crash-safe) + rootIf.db.recordModify(model, sourceId, {"ragIndexEnabled": ragIndexEnabled}) + logger.info( - "Updated ragIndexEnabled=%s for DataSource %s (cascade-reset %d descendants)", - ragIndexEnabled, sourceId, cascaded, + "Updated ragIndexEnabled=%s for %s %s (cascade-reset %d descendants)", + ragIndexEnabled, model.__name__, sourceId, len(resetIds), ) - connectionId = rec.get("connectionId") or rec.get("connection_id") or "" - if ragIndexEnabled is True: - _ensureConnectionKnowledgeFlag(rootIf, connectionId) - from modules.serviceCenter.services.serviceBackgroundJobs import startJob + # Bootstrap / purge only for personal DataSource (file/folder-based RAG). + # FDS RAG is handled by the feature pipeline; the flag alone is enough. + if model is DataSource: + connectionId = rec.get("connectionId") or rec.get("connection_id") or "" + if ragIndexEnabled is True: + _ensureConnectionKnowledgeFlag(rootIf, connectionId) + from modules.serviceCenter.services.serviceBackgroundJobs import startJob - conn = rootIf.getUserConnectionById(connectionId) if connectionId else None - authority = "" - if conn: - authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "") + conn = rootIf.getUserConnectionById(connectionId) if connectionId else None + authority = "" + if conn: + authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "") - await startJob( - "connection.bootstrap", - {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]}, - triggeredBy=str(context.user.id), - ) - elif ragIndexEnabled is False: - from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface - purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId) - logger.info("Purged %d index rows / %d chunks for DataSource %s", - purgeResult.get("indexRows", 0), purgeResult.get("chunks", 0), sourceId) + await startJob( + "connection.bootstrap", + {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]}, + triggeredBy=str(context.user.id), + ) + elif ragIndexEnabled is False: + from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface + purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId) + logger.info("Purged %d index rows / %d chunks for DataSource %s", + purgeResult.get("indexRows", 0), purgeResult.get("chunks", 0), sourceId) import json from modules.shared.auditLogger import audit_logger @@ -253,10 +335,20 @@ async def _updateDataSourceRagIndex( mandateId=context.mandateId, category=AuditCategory.PERMISSION.value, action="rag_index_toggled", - details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "cascadedDescendants": cascaded}), + details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "resetDescendants": len(resetIds), "model": model.__name__}), ) - return {"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "updated": True, "cascadedDescendants": cascaded} + # 3. Compute effective + ancestors for response + updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "ragIndexEnabled") + effectiveRag = _computeOwnEffective(rootIf, rec, model, sourceId, "ragIndexEnabled") + + return { + "sourceId": sourceId, + "ragIndexEnabled": ragIndexEnabled, + "effectiveRagIndexEnabled": effectiveRag, + "resetDescendantIds": resetIds, + "updatedAncestors": updatedAncestors, + } except HTTPException: raise except Exception as e: @@ -339,7 +431,17 @@ def _updateDataSourceSettings( ownerId = str(rec.get("userId") or "") currentUserId = str(context.user.id) if ownerId and ownerId != currentUserId and not context.isSysAdmin: - scope = str(rec.get("scope") or "personal") + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag + if model is DataSource: + connectionId = rec.get("connectionId", "") + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + scope = str(getEffectiveFlag(rec, "scope", allDs, mode="walk")) + else: + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource as FDS + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + wsId = rec.get("workspaceInstanceId", "") + allFds = rootIf.db.getRecordset(FDS, recordFilter={"workspaceInstanceId": wsId}) + scope = str(getEffectiveFlagFds(rec, "scope", allFds, mode="walk")) isMandateAdmin = getattr(context, "isMandateAdmin", False) if scope == "personal" or not isMandateAdmin: raise HTTPException(status_code=403, detail="Not allowed to modify this DataSource's settings") diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py index 99d5c4df..6a5e9eb5 100644 --- a/modules/routes/routeRagInventory.py +++ b/modules/routes/routeRagInventory.py @@ -86,6 +86,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L """ from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelKnowledge import FileContentIndex + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag out = [] for conn in connections: @@ -136,8 +137,8 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "label": ds.get("label") if isinstance(ds, dict) else getattr(ds, "label", ""), "path": dsPath, "sourceType": ds.get("sourceType") if isinstance(ds, dict) else getattr(ds, "sourceType", ""), - "ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False), - "neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False), + "ragIndexEnabled": getEffectiveFlag(ds, "ragIndexEnabled", dataSources, mode="walk"), + "neutralize": getEffectiveFlag(ds, "neutralize", dataSources, mode="walk"), "lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None), "fileCount": filesByDs.get(dsId, 0), "chunkCount": chunksByDs.get(dsId, 0), @@ -223,13 +224,165 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L return out +def _buildFeatureInstanceInventory(featureInstanceIds, rootIf, knowledgeIf) -> List[Dict[str, Any]]: + """Build per-feature-instance RAG inventory rows. + + Feature-instance data lives in FileContentIndex with a non-empty + featureInstanceId. Additionally each feature instance may have + FeatureDataSource rows that define which tables/data are visible + as sources, with their own ragIndexEnabled flags. + Includes feature.bootstrap job status (running/success/error). + """ + from modules.datamodels.datamodelKnowledge import FileContentIndex + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + from modules.interfaces.interfaceFeatures import getFeatureInterface + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService + from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import FEATURE_BOOTSTRAP_JOB_TYPE + + featureIf = getFeatureInterface(rootIf.db) + + allFeatureJobs = jobService.listJobs(jobType=FEATURE_BOOTSTRAP_JOB_TYPE, limit=100) + + out = [] + for fiId in featureInstanceIds: + instance = featureIf.getFeatureInstance(fiId) + if not instance or not instance.enabled: + continue + + indexRows = knowledgeIf.db.getRecordset( + FileContentIndex, recordFilter={"featureInstanceId": fiId} + ) + fileIds = [ + (r.get("id") if isinstance(r, dict) else getattr(r, "id", "")) + for r in indexRows + ] + fileIds = [fid for fid in fileIds if fid] + chunkCounts = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {} + + statusCounts: Dict[str, int] = {} + for r in indexRows: + st = (r.get("status") if isinstance(r, dict) else getattr(r, "status", "unknown")) or "unknown" + statusCounts[st] = statusCounts.get(st, 0) + 1 + + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": fiId}) + dsItems = [] + anyRagEnabled = False + for fds in allFds: + tblName = (fds.get("tableName") if isinstance(fds, dict) else getattr(fds, "tableName", "")) or "" + fCode = (fds.get("featureCode") if isinstance(fds, dict) else getattr(fds, "featureCode", "")) or "" + if tblName == "*" or not fCode: + continue + fdsId = fds.get("id") if isinstance(fds, dict) else getattr(fds, "id", "") + ragEnabled = getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate") + if ragEnabled: + anyRagEnabled = True + dsItems.append({ + "id": fdsId, + "label": (fds.get("label") if isinstance(fds, dict) else getattr(fds, "label", "")) or "", + "tableName": tblName, + "featureCode": fCode, + "ragIndexEnabled": ragEnabled, + }) + + fiJobs = [ + j for j in allFeatureJobs + if (j.get("payload") or {}).get("workspaceInstanceId") == fiId + ] + runningJobs = [ + { + "jobId": j["id"], + "progress": j.get("progress", 0), + "progressMessage": ( + resolveJobMessage(j.get("progressMessageData")) + or j.get("progressMessage", "") + ), + } + for j in fiJobs + if j.get("status") in ("PENDING", "RUNNING") + ] + lastError: Optional[Dict[str, Any]] = None + lastSuccess: Optional[Dict[str, Any]] = None + for j in fiJobs: + jStatus = j.get("status") + if jStatus == "ERROR" and lastError is None: + lastError = { + "jobId": j["id"], + "errorMessage": j.get("errorMessage", ""), + "finishedAt": j.get("finishedAt"), + } + elif jStatus == "SUCCESS" and lastSuccess is None: + result = j.get("result") or {} + lastSuccess = { + "jobId": j["id"], + "finishedAt": j.get("finishedAt"), + "indexed": result.get("indexed", 0), + "skippedDuplicate": result.get("skippedDuplicate", 0), + "failed": result.get("failed", 0), + } + if lastError and lastSuccess: + break + + if not indexRows and not dsItems: + continue + + out.append({ + "featureInstanceId": fiId, + "featureCode": instance.featureCode, + "label": instance.label or instance.featureCode, + "mandateId": str(instance.mandateId) if instance.mandateId else "", + "fileCount": len(indexRows), + "chunkCount": sum(chunkCounts.values()), + "statusCounts": statusCounts, + "dataSources": dsItems, + "ragEnabled": anyRagEnabled, + "runningJobs": runningJobs, + "lastSuccess": lastSuccess, + "lastError": lastError, + }) + return out + + +@router.get("/my-mandates") +@limiter.limit("30/minute") +def _getMyMandates( + request: Request, + currentUser: User = Depends(getCurrentUser), +) -> List[Dict[str, Any]]: + """Return mandates where the current user has an active membership. + + Used by the RAG inventory frontend to populate the mandate dropdown + without requiring admin rights (unlike GET /api/mandates/). + """ + try: + from modules.interfaces.interfaceDbApp import getRootInterface + rootIf = getRootInterface() + userMandates = rootIf.getUserMandates(str(currentUser.id)) + result = [] + for um in userMandates: + if not um.enabled: + continue + mandate = rootIf.getMandate(str(um.mandateId)) + if not mandate or not getattr(mandate, "enabled", True): + continue + result.append({ + "id": str(um.mandateId), + "name": getattr(mandate, "name", ""), + "label": getattr(mandate, "label", None) or getattr(mandate, "name", ""), + }) + return result + except Exception as e: + logger.error("Error in RAG inventory /my-mandates: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/me") @limiter.limit("30/minute") def _getInventoryMe( request: Request, currentUser: User = Depends(getCurrentUser), ) -> Dict[str, Any]: - """Personal RAG inventory: own connections + DataSources + chunk counts.""" + """Personal RAG inventory: own connections + DataSources + chunk counts + feature uploads.""" try: from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface @@ -243,7 +396,20 @@ def _getInventoryMe( totalChunks = sum(c.get("totalChunks", 0) for c in items) totalFiles = sum(c.get("totalFiles", 0) for c in items) - return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} + featureAccesses = rootIf.getFeatureAccessesForUser(str(currentUser.id)) + fiIds = [ + str(fa.featureInstanceId) for fa in featureAccesses + if fa.enabled and fa.featureInstanceId + ] + fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf) + totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems) + totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems) + + return { + "connections": items, + "featureInstances": fiItems, + "totals": {"files": totalFiles, "chunks": totalChunks}, + } except Exception as e: logger.error("Error in RAG inventory /me: %s", e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @@ -262,21 +428,43 @@ def _getInventoryMandate( from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface, aggregateMandateRagTotalBytes from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService - rootIf = getRootInterface() knowledgeIf = getKnowledgeInterface(None) - mandateId = str(context.mandateId) if context.mandateId else "" + mandateId = str(context.mandateId) + userId = str(context.user.id) - from modules.datamodels.datamodelUam import UserConnection - allConnections = rootIf.db.getRecordset(UserConnection, recordFilter={"mandateId": mandateId}) - connectionObjects = [type("C", (), row)() if isinstance(row, dict) else row for row in allConnections] + userMandates = rootIf.getUserMandates(userId) + isMember = any( + getattr(um, "mandateId", None) == mandateId and um.enabled + for um in userMandates + ) + if not isMember and not context.isSysAdmin: + raise HTTPException(status_code=403, detail=routeApiMsg("No membership in this mandate")) - items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) + mandateMembers = rootIf.getUserMandatesByMandate(mandateId) + memberUserIds = {getattr(um, "userId", None) for um in mandateMembers} + memberUserIds.discard(None) + + allConnections = [] + for uid in memberUserIds: + allConnections.extend(rootIf.getUserConnections(uid)) + + items = _buildConnectionInventory(allConnections, rootIf, knowledgeIf, jobService) totalChunks = sum(c.get("totalChunks", 0) for c in items) totalFiles = sum(c.get("totalFiles", 0) for c in items) totalBytes = aggregateMandateRagTotalBytes(mandateId) - return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}} + mandateInstances = rootIf.getFeatureInstancesByMandate(mandateId, enabledOnly=True) + fiIds = [str(inst.id) for inst in mandateInstances if inst.id] + fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf) + totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems) + totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems) + + return { + "connections": items, + "featureInstances": fiItems, + "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}, + } except HTTPException: raise except Exception as e: @@ -308,7 +496,22 @@ def _getInventoryPlatform( totalChunks = sum(c.get("totalChunks", 0) for c in items) totalFiles = sum(c.get("totalFiles", 0) for c in items) - return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} + from modules.datamodels.datamodelFeatures import FeatureInstance + allInstances = rootIf.db.getRecordset(FeatureInstance, recordFilter={"enabled": True}) + fiIds = [ + (r.get("id") if isinstance(r, dict) else getattr(r, "id", "")) + for r in allInstances + ] + fiIds = [fid for fid in fiIds if fid] + fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf) + totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems) + totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems) + + return { + "connections": items, + "featureInstances": fiItems, + "totals": {"files": totalFiles, "chunks": totalChunks}, + } except HTTPException: raise except Exception as e: @@ -345,8 +548,9 @@ async def _reindexConnection( if str(conn.userId) != str(currentUser.id): raise HTTPException(status_code=403, detail="Not your connection") + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) - ragDs = [ds for ds in dataSources if (ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False))] + ragDs = [ds for ds in dataSources if getEffectiveFlag(ds, "ragIndexEnabled", dataSources, mode="walk") is True] if not ragDs: return {"status": "skipped", "reason": "no_rag_enabled_datasources"} @@ -368,6 +572,47 @@ async def _reindexConnection( raise HTTPException(status_code=500, detail=str(e)) +@router.post("/reindex-feature/{workspaceInstanceId}") +@limiter.limit("10/minute") +async def _reindexFeature( + request: Request, + workspaceInstanceId: str, + currentUser: User = Depends(getCurrentUser), +) -> Dict[str, Any]: + """Re-trigger feature data bootstrap for a workspace instance. + + Indexes all RAG-enabled FeatureDataSource rows into the knowledge store. + Must be ``async def`` so ``await startJob(...)`` registers in the main loop. + """ + try: + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.serviceCenter.services.serviceBackgroundJobs import startJob + from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import FEATURE_BOOTSTRAP_JOB_TYPE + + rootIf = getRootInterface() + featureAccesses = rootIf.getFeatureAccessesForUser(str(currentUser.id)) + hasAccess = any( + str(fa.featureInstanceId) == workspaceInstanceId and fa.enabled + for fa in featureAccesses + ) + if not hasAccess and not getattr(currentUser, "isSysAdmin", False): + raise HTTPException(status_code=403, detail="No access to this feature instance") + + jobId = await startJob( + FEATURE_BOOTSTRAP_JOB_TYPE, + {"workspaceInstanceId": workspaceInstanceId}, + triggeredBy=str(currentUser.id), + ) + + logger.info("Feature reindex triggered for workspace %s (jobId=%s)", workspaceInstanceId, jobId) + return {"status": "queued", "workspaceInstanceId": workspaceInstanceId, "jobId": jobId} + except HTTPException: + raise + except Exception as e: + logger.error("Error triggering feature reindex: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/jobs") @limiter.limit("60/minute") def _getActiveJobs( diff --git a/modules/security/rbac.py b/modules/security/rbac.py index bec0b70e..59f8f55f 100644 --- a/modules/security/rbac.py +++ b/modules/security/rbac.py @@ -341,11 +341,10 @@ class RbacClass: return [] try: - conn = self.dbApp.connection roleIds = set() - + # 1. Mandant-Rollen via UserMandate → UserMandateRole (SINGLE Query) - with conn.cursor() as cursor: + with self.dbApp.borrowCursor() as cursor: cursor.execute( """ SELECT umr."roleId" @@ -357,10 +356,10 @@ class RbacClass: ) mandateRoles = cursor.fetchall() roleIds.update(r["roleId"] for r in mandateRoles if r.get("roleId")) - + # 2. Instanz-Rollen via FeatureAccess → FeatureAccessRole (SINGLE Query) if featureInstanceId: - with conn.cursor() as cursor: + with self.dbApp.borrowCursor() as cursor: cursor.execute( """ SELECT far."roleId" @@ -372,14 +371,13 @@ class RbacClass: ) instanceRoles = cursor.fetchall() roleIds.update(r["roleId"] for r in instanceRoles if r.get("roleId")) - + if not roleIds: return [] - + # 3. BULK Query: Alle Regeln für alle Rollen + zugehörige Role-Daten - # SINGLE Query mit JOIN statt N+1 roleIdsList = list(roleIds) - with conn.cursor() as cursor: + with self.dbApp.borrowCursor() as cursor: cursor.execute( """ SELECT ar.*, r."mandateId" as "roleMandateId", diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py index fff1bcb3..dbd28dd4 100644 --- a/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py +++ b/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py @@ -67,7 +67,12 @@ def _registerDataSourceTools(registry: ToolRegistry, services): sourceType = ds.get("sourceType", "") path = ds.get("path", "/") label = ds.get("label", "") - neutralize = bool(ds.get("neutralize", False)) + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag + from modules.datamodels.datamodelDataSource import DataSource + from modules.interfaces.interfaceDbApp import getRootInterface + rootIf = getRootInterface() + allConnDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + neutralize = bool(getEffectiveFlag(ds, "neutralize", allConnDs or [ds], mode="walk")) service = _SOURCE_TYPE_TO_SERVICE.get(sourceType, sourceType) if not connectionId: raise ValueError(f"DataSource '{dsId}' has no connectionId") diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py index bdb3d23b..2ebc2720 100644 --- a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py +++ b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py @@ -110,9 +110,11 @@ def _registerFeatureSubAgentTools(registry: ToolRegistry, services): recordFilter={"featureInstanceId": featureInstanceId, "workspaceInstanceId": workspaceInstanceId}, ) + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + _fdsAll = featureDataSources or [] _anySourceNeutralize = any( - bool(ds.get("neutralize", False) if isinstance(ds, dict) else getattr(ds, "neutralize", False)) - for ds in (featureDataSources or []) + getEffectiveFlagFds(ds, "neutralize", _fdsAll, mode="walk") is True + for ds in _fdsAll ) neutralizeFieldsPerTable: Dict[str, List[str]] = {} diff --git a/modules/serviceCenter/services/serviceAgent/featureDataProvider.py b/modules/serviceCenter/services/serviceAgent/featureDataProvider.py index d7707bdf..27ec36b2 100644 --- a/modules/serviceCenter/services/serviceAgent/featureDataProvider.py +++ b/modules/serviceCenter/services/serviceAgent/featureDataProvider.py @@ -95,8 +95,7 @@ class FeatureDataProvider: def getActualColumns(self, tableName: str) -> List[str]: """Read real column names from PostgreSQL information_schema.""" try: - conn = self._db.connection - with conn.cursor() as cur: + with self._db.borrowCursor() as cur: cur.execute( "SELECT column_name FROM information_schema.columns " "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " @@ -131,7 +130,6 @@ class FeatureDataProvider: Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``. """ _validateTableName(tableName) - conn = self._db.connection if fields: invalid = [f for f in fields if not _isValidIdentifier(f)] @@ -141,7 +139,7 @@ class FeatureDataProvider: "error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.", } - scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn) + scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db) extraWhere, extraParams = _buildFilterClauses(extraFilters) fullWhere = scopeFilter["where"] @@ -152,7 +150,7 @@ class FeatureDataProvider: t0 = time.time() try: - with conn.cursor() as cur: + with self._db.borrowCursor() as cur: countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}' cur.execute(countSql, allParams) total = cur.fetchone()["count"] if cur.rowcount else 0 @@ -179,10 +177,6 @@ class FeatureDataProvider: _debugQueryLog("browseTable", tableName, { "fields": fields, "limit": limit, "offset": offset, }, errResult, elapsed) - try: - conn.rollback() - except Exception: - pass return errResult def aggregateTable( @@ -208,8 +202,7 @@ class FeatureDataProvider: if groupBy and not _isValidIdentifier(groupBy): return {"rows": [], "error": f"Invalid groupBy field: {groupBy}"} - conn = self._db.connection - scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn) + scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db) extraWhere, extraParams = _buildFilterClauses(extraFilters) fullWhere = scopeFilter["where"] @@ -220,7 +213,7 @@ class FeatureDataProvider: t0 = time.time() try: - with conn.cursor() as cur: + with self._db.borrowCursor() as cur: if groupBy: sql = ( f'SELECT "{groupBy}" AS "groupValue", {aggregate}("{field}") AS "result" ' @@ -253,10 +246,6 @@ class FeatureDataProvider: _debugQueryLog("aggregateTable", tableName, { "aggregate": aggregate, "field": field, "groupBy": groupBy, }, errResult, elapsed) - try: - conn.rollback() - except Exception: - pass return errResult def queryTable( @@ -277,7 +266,6 @@ class FeatureDataProvider: ``extraFilters`` are mandatory record-level scoping filters injected by the pipeline. """ _validateTableName(tableName) - conn = self._db.connection if fields: invalid = [f for f in fields if not _isValidIdentifier(f)] @@ -287,7 +275,7 @@ class FeatureDataProvider: "error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.", } - scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn) + scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db) combinedFilters = list(filters or []) + list(extraFilters or []) extraWhere, extraParams = _buildFilterClauses(combinedFilters if combinedFilters else None) @@ -300,7 +288,7 @@ class FeatureDataProvider: t0 = time.time() try: - with conn.cursor() as cur: + with self._db.borrowCursor() as cur: countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}' cur.execute(countSql, allParams) total = cur.fetchone()["count"] if cur.rowcount else 0 @@ -329,10 +317,6 @@ class FeatureDataProvider: "filters": filters, "fields": fields, "orderBy": orderBy, "limit": limit, "offset": offset, }, errResult, elapsed) - try: - conn.rollback() - except Exception: - pass return errResult @@ -343,13 +327,13 @@ class FeatureDataProvider: _instanceColCache: Dict[str, str] = {} -def _resolveInstanceColumn(tableName: str, dbConnection=None) -> str: +def _resolveInstanceColumn(tableName: str, db=None) -> str: """Detect whether the table uses ``instanceId`` or ``featureInstanceId``.""" if tableName in _instanceColCache: return _instanceColCache[tableName] - if dbConnection: + if db: try: - with dbConnection.cursor() as cur: + with db.borrowCursor() as cur: cur.execute( "SELECT column_name FROM information_schema.columns " "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " @@ -378,14 +362,14 @@ def _isValidIdentifier(name: str) -> bool: return name.isidentifier() -def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, dbConnection=None) -> Dict[str, Any]: +def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, db=None, dbConnection=None) -> Dict[str, Any]: """Build the mandatory WHERE clause that scopes rows to the feature instance. Feature tables use either ``instanceId`` (commcoach, teamsbot) or ``featureInstanceId`` (trustee) as the FK. We detect the actual column - from ``information_schema`` when a DB connection is provided. + from ``information_schema`` when a DB connector is provided. """ - instanceCol = _resolveInstanceColumn(tableName, dbConnection) + instanceCol = _resolveInstanceColumn(tableName, db or dbConnection) conditions = [] params = [] diff --git a/modules/serviceCenter/services/serviceKnowledge/_buildTree.py b/modules/serviceCenter/services/serviceKnowledge/_buildTree.py new file mode 100644 index 00000000..9179f3d8 --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/_buildTree.py @@ -0,0 +1,1020 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Generic UDB Tree builder. + +The UDB shows three logical hierarchies as a single user-facing tree: + 1. Personal connections: UserConnection -> Service -> Folder -> File + 2. Mandate groups -> Feature instances -> FDS Workspace(*) -> FDS Table -> FDS Record + 3. (Settings/diagnostics nodes can be added later under the same model.) + +For every visible node the UI needs: + - a stable `key` (used both for expand-state and as parent reference) + - a `kind`, `label`, optional `icon` + - effective values for all three flags (neutralize, scope, ragIndexEnabled) + - whether a backing DB record exists (`dataSourceId` + `modelType`) + - whether the node has children to expand + +This module exposes one function: `getChildrenForParents(parents, ...)`. +The caller asks for the children of a list of parent keys. The orchestrator +does NOT decide what to expand; it only returns the children of what was +asked for. This keeps the contract minimal and predictable. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, Tuple + +from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + resolveEffectiveForPath, + resolveEffectiveForFds, + _normalisePath, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Key encoding / decoding +# --------------------------------------------------------------------------- +# Format: "|||..." for data-bearing keys. +# Synthetic container keys use a single literal token without separator. +# +# Top-level (parent=None) returns: +# personalRoot (synthetic, groups all UserConnections) +# mgrp| (one per accessible mandate) +# +# Data-bearing: +# conn| +# svc|| +# ds||| +# mgrp| +# feat||| +# fdsws|| (synthetic '*' wildcard) +# fdstbl|| +# fdsrec||| + +_KEY_SEP = "|" + +# Stable, parseable synthetic-container key. Never encoded with `_encode` +# (no payload parts), always emitted/matched as literal. +_KEY_PERSONAL_ROOT = "personalRoot" + + +def _decode(key: str) -> Tuple[str, List[str]]: + parts = key.split(_KEY_SEP) + return parts[0], parts[1:] + + +def _encode(kind: str, *parts: str) -> str: + return _KEY_SEP.join((kind, *parts)) + + +# --------------------------------------------------------------------------- +# Sourcetype mapping (was hard-coded in frontend; now backend authority) +# --------------------------------------------------------------------------- +_SERVICE_TO_SOURCE_TYPE: Dict[str, str] = { + "sharepoint": "sharepointFolder", + "onedrive": "onedriveFolder", + "outlook": "outlookFolder", + "drive": "googleDriveFolder", + "gmail": "gmailFolder", + "files": "ftpFolder", + "clickup": "clickup", + "kdrive": "kdriveFolder", + "mail": "mailFolder", + "calendar": "calendarFolder", + "contact": "contactFolder", +} + +_SERVICE_LABELS: Dict[str, str] = { + "sharepoint": "SharePoint", + "outlook": "Outlook", + "teams": "Teams", + "onedrive": "OneDrive", + "drive": "Google Drive", + "gmail": "Gmail", + "files": "Files (FTP)", + "kdrive": "kDrive", + "calendar": "Calendar", + "contact": "Contacts", +} + + +# --------------------------------------------------------------------------- +# Per-node effective-value helpers +# --------------------------------------------------------------------------- + +def _effectiveTripletDs( + connectionId: str, + sourceType: str, + path: str, + allDs: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Return {effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled} + for an arbitrary DS coordinate (whether or not a record exists).""" + out = resolveEffectiveForPath(connectionId, sourceType, path, allDs, mode="aggregate") + return { + "effectiveNeutralize": out.get("effectiveNeutralize", False), + "effectiveScope": out.get("effectiveScope", "personal"), + "effectiveRagIndexEnabled": out.get("effectiveRagIndexEnabled", False), + } + + +def _effectiveTripletFds( + featureInstanceId: str, + tableName: str, + recordFilter: Optional[Dict[str, str]], + allFds: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Return effective-triplet for an FDS coordinate.""" + out = resolveEffectiveForFds(featureInstanceId, tableName, recordFilter, allFds, mode="aggregate") + return { + "effectiveNeutralize": out.get("effectiveNeutralize", False), + "effectiveScope": out.get("effectiveScope", "personal"), + "effectiveRagIndexEnabled": out.get("effectiveRagIndexEnabled", False), + } + + +def _findDsRecord( + allDs: List[Dict[str, Any]], + connectionId: str, + sourceType: str, + path: str, +) -> Optional[Dict[str, Any]]: + norm = _normalisePath(path) + for ds in allDs: + if ( + ds.get("connectionId") == connectionId + and ds.get("sourceType") == sourceType + and _normalisePath(ds.get("path")) == norm + ): + return ds + return None + + +def _findFdsRecord( + allFds: List[Dict[str, Any]], + featureInstanceId: str, + tableName: str, + recordFilter: Optional[Dict[str, str]] = None, +) -> Optional[Dict[str, Any]]: + """Find a FeatureDataSource record by featureInstanceId + tableName. + + `allFds` is already scoped to the workspace (loaded with + recordFilter={'workspaceInstanceId': wsInstanceId}), so the + distinguishing coordinate is featureInstanceId + tableName. + """ + target = recordFilter or None + for fds in allFds: + if ( + fds.get("featureInstanceId") == featureInstanceId + and fds.get("tableName") == tableName + and (fds.get("recordFilter") or None) == target + ): + return fds + return None + + +# --------------------------------------------------------------------------- +# Synthetic container helpers +# --------------------------------------------------------------------------- + +def _emptyTriplet() -> Dict[str, Any]: + """Synthetic container nodes carry no DB record and no inherited flags. + Backend reports neutral defaults so the UI never reads stale values for them.""" + return { + "effectiveNeutralize": False, + "effectiveScope": "personal", + "effectiveRagIndexEnabled": False, + } + + +def _syntheticNode( + key: str, + parentKey: Optional[str], + label: str, + icon: str, + displayOrder: int, + defaultExpanded: bool = False, +) -> Dict[str, Any]: + """Build a synthetic container node (no DB record, not flag-toggleable).""" + return { + "key": key, + "kind": "synthRoot", + "parentKey": parentKey, + "label": label, + "icon": icon, + "hasChildren": True, + "dataSourceId": None, + "modelType": None, + **_emptyTriplet(), + "supportsRag": False, + "canBeAdded": False, + "displayOrder": displayOrder, + "defaultExpanded": defaultExpanded, + } + + +# --------------------------------------------------------------------------- +# Top-level (parent = None) -> personalRoot + mandate groups (flat layout) +# --------------------------------------------------------------------------- + +def _topLevel( + instanceId: str, + context: Any, + rootIf: Any, + _allDs: List[Dict[str, Any]], + allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return the visible top-level: 'personalRoot' first, then one node per + accessible mandate group. Both layers are marked `defaultExpanded=True` + so the UI opens down to the data-source level on first render. + """ + nodes: List[Dict[str, Any]] = [ + _syntheticNode( + key=_KEY_PERSONAL_ROOT, + parentKey=None, + label=resolveTextSafe("Persönliche Quellen"), + icon="person", + displayOrder=0, + defaultExpanded=True, + ) + ] + nodes.extend(_listMandateGroups(instanceId, context, rootIf, allFds)) + return nodes + + +# --------------------------------------------------------------------------- +# Children of personalRoot -> active UserConnections +# --------------------------------------------------------------------------- + +def _personalRootChildren( + instanceId: str, + context: Any, + allDs: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return one node per active UserConnection of the current user.""" + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + + mandateId = getattr(context, "mandateId", "") or "" + ctx = ServiceCenterContext( + user=context.user, + mandate_id=mandateId, + feature_instance_id=instanceId, + ) + chatService = getService("chat", ctx) + connections = chatService.getUserConnections() or [] + + nodes: List[Dict[str, Any]] = [] + for c in connections: + conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {}) + status = conn.get("status") + if hasattr(status, "value"): + status = status.value + if status != "active": + continue + authority = conn.get("authority") + if hasattr(authority, "value"): + authority = authority.value + connId = conn.get("id") or "" + label = conn.get("externalEmail") or conn.get("externalUsername") or authority or "" + # Connection root = path '/' on its authority sourceType. + triplet = _effectiveTripletDs(connId, str(authority), "/", allDs) + rec = _findDsRecord(allDs, connId, str(authority), "/") + nodes.append({ + "key": _encode("conn", connId), + "kind": "connection", + "parentKey": _KEY_PERSONAL_ROOT, + "label": label, + "icon": str(authority), + "hasChildren": True, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "DataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "authority": authority, + "connectionId": connId, + }) + return nodes + + +# --------------------------------------------------------------------------- +# Mandate-group nodes (rendered top-level next to personalRoot) +# --------------------------------------------------------------------------- + +def _listMandateGroups( + _instanceId: str, + context: Any, + rootIf: Any, + _allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return one mandate-group node per accessible mandate that has at least + one enabled feature instance with registered DATA objects. + + Emitted at the top level (parentKey=None). `defaultExpanded=True` so the + UI shows feature-instance children (= mandate data sources) without a + second user click. + """ + from modules.security.rbacCatalog import getCatalogService + from modules.datamodels.datamodelUam import Mandate + + userId = str(context.user.id) + catalog = getCatalogService() + featureCodesWithData = catalog.getFeaturesWithDataObjects() + userMandates = rootIf.getUserMandates(userId) + + wsMandateId = getattr(context, "mandateId", None) + allowedMandateIds = {um.mandateId for um in (userMandates or [])} + if wsMandateId and wsMandateId in allowedMandateIds: + allowedMandateIds = {wsMandateId} + + mandateLabels: Dict[str, str] = {} + for um in userMandates or []: + if um.mandateId not in allowedMandateIds: + continue + try: + rows = rootIf.db.getRecordset(Mandate, recordFilter={"id": um.mandateId}) + if rows: + m = rows[0] + mandateLabels[um.mandateId] = m.get("label") or m.get("name") or um.mandateId + except Exception: + mandateLabels[um.mandateId] = um.mandateId + + nodes: List[Dict[str, Any]] = [] + seenMandates: set = set() + for um in userMandates or []: + mid = um.mandateId + if mid in seenMandates or mid not in allowedMandateIds: + continue + seenMandates.add(mid) + instances = rootIf.getFeatureInstancesByMandate(mid) + hasFeature = False + for inst in instances: + if inst.enabled and inst.featureCode in featureCodesWithData: + fa = rootIf.getFeatureAccess(userId, inst.id) + if fa and fa.enabled: + hasFeature = True + break + if not hasFeature: + continue + nodes.append({ + "key": _encode("mgrp", mid), + "kind": "mandateGroup", + "parentKey": None, + "label": mandateLabels.get(mid, mid), + "icon": "mandate", + "hasChildren": True, + "dataSourceId": None, + "modelType": None, + **_emptyTriplet(), + "supportsRag": False, + "canBeAdded": False, + "mandateId": mid, + "defaultExpanded": True, + }) + return nodes + + +# --------------------------------------------------------------------------- +# Children of a connection -> services +# --------------------------------------------------------------------------- + +async def _connectionServices( + instanceId: str, + context: Any, + connectionId: str, + allDs: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + from modules.connectors.connectorResolver import ConnectorResolver + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + + mandateId = getattr(context, "mandateId", "") or "" + ctx = ServiceCenterContext( + user=context.user, + mandate_id=mandateId, + feature_instance_id=instanceId, + ) + chatService = getService("chat", ctx) + securityService = getService("security", ctx) + from modules.features.workspace.routeFeatureWorkspace import _buildResolverDbInterface + dbInterface = _buildResolverDbInterface(chatService) + resolver = ConnectorResolver(securityService, dbInterface) + try: + provider = await resolver.resolve(connectionId) + services = provider.getAvailableServices() + except Exception as exc: + logger.error("Tree: cannot resolve services for connection %s: %s", connectionId, exc) + return [] + + nodes: List[Dict[str, Any]] = [] + for service in services or []: + sourceType = _SERVICE_TO_SOURCE_TYPE.get(service, service) + triplet = _effectiveTripletDs(connectionId, sourceType, "/", allDs) + rec = _findDsRecord(allDs, connectionId, sourceType, "/") + nodes.append({ + "key": _encode("svc", connectionId, service), + "kind": "service", + "parentKey": _encode("conn", connectionId), + "label": _SERVICE_LABELS.get(service, service), + "icon": service, + "hasChildren": True, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "DataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "connectionId": connectionId, + "service": service, + "sourceType": sourceType, + "path": "/", + }) + return nodes + + +# --------------------------------------------------------------------------- +# Children of a folder/service -> next-level folders+files via browse +# --------------------------------------------------------------------------- + +async def _browseChildren( + instanceId: str, + context: Any, + connectionId: str, + service: str, + sourceType: str, + parentPath: str, + allDs: List[Dict[str, Any]], + parentKey: Optional[str] = None, +) -> List[Dict[str, Any]]: + from modules.connectors.connectorResolver import ConnectorResolver + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + + mandateId = getattr(context, "mandateId", "") or "" + ctx = ServiceCenterContext( + user=context.user, + mandate_id=mandateId, + feature_instance_id=instanceId, + ) + chatService = getService("chat", ctx) + securityService = getService("security", ctx) + from modules.features.workspace.routeFeatureWorkspace import _buildResolverDbInterface + dbInterface = _buildResolverDbInterface(chatService) + resolver = ConnectorResolver(securityService, dbInterface) + try: + adapter = await resolver.resolveService(connectionId, service) + entries = await adapter.browse(parentPath, filter=None) + except Exception as exc: + logger.error("Tree: cannot browse %s on connection %s path=%s: %s", service, connectionId, parentPath, exc) + return [] + + # Children parentKey must equal the key the caller asked for (= the + # currently-expanded node in the UI). If the caller doesn't pass an + # explicit key, fall back to the encoded ds-coordinate. + effectiveParentKey = parentKey if parentKey is not None else _encode("ds", connectionId, sourceType, parentPath) + nodes: List[Dict[str, Any]] = [] + for e in entries or []: + path = getattr(e, "path", "") or "" + kind = "folder" if getattr(e, "isFolder", False) else "file" + triplet = _effectiveTripletDs(connectionId, sourceType, path, allDs) + rec = _findDsRecord(allDs, connectionId, sourceType, path) + nodes.append({ + "key": _encode("ds", connectionId, sourceType, path), + "kind": kind, + "parentKey": effectiveParentKey, + "label": getattr(e, "name", "") or path, + "icon": kind, + "hasChildren": kind == "folder", + "dataSourceId": rec.get("id") if rec else None, + "modelType": "DataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "connectionId": connectionId, + "service": service, + "sourceType": sourceType, + "path": path, + }) + return nodes + + +# --------------------------------------------------------------------------- +# Mandate group -> feature connections +# --------------------------------------------------------------------------- + +def _featureConnectionsForMandate( + instanceId: str, + context: Any, + rootIf: Any, + mandateId: str, + allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + from modules.security.rbacCatalog import getCatalogService + + userId = str(context.user.id) + catalog = getCatalogService() + featureCodesWithData = catalog.getFeaturesWithDataObjects() + instances = rootIf.getFeatureInstancesByMandate(mandateId) + + parentKey = _encode("mgrp", mandateId) + nodes: List[Dict[str, Any]] = [] + for inst in instances or []: + if not inst.enabled: + continue + if inst.featureCode not in featureCodesWithData: + continue + fa = rootIf.getFeatureAccess(userId, inst.id) + if not fa or not fa.enabled: + continue + # Effective values come from the FDS workspace-wildcard for this featureInstanceId + wsId = inst.id + triplet = _effectiveTripletFds(wsId, "*", None, allFds) + rec = _findFdsRecord(allFds, wsId, "*", None) + featureDef = catalog.getFeatureDefinition(inst.featureCode) or {} + nodes.append({ + "key": _encode("feat", mandateId, inst.featureCode, inst.id), + "kind": "featureNode", + "parentKey": parentKey, + "label": inst.label or inst.featureCode, + "icon": featureDef.get("icon", "mdi-database"), + "hasChildren": True, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "FeatureDataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "featureInstanceId": wsId, + "featureCode": inst.featureCode, + "mandateId": mandateId, + "tableName": "*", + }) + return nodes + + +# --------------------------------------------------------------------------- +# Feature node -> tables +# --------------------------------------------------------------------------- + +def _featureTables( + context: Any, + rootIf: Any, + parentKey: str, + featureInstanceId: str, + featureCode: str, + allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + from modules.security.rbacCatalog import getCatalogService + + inst = rootIf.getFeatureInstance(featureInstanceId) + if not inst: + return [] + catalog = getCatalogService() + try: + from modules.security.rbac import RbacClass + from modules.security.rootAccess import getRootDbAppConnector + dbApp = getRootDbAppConnector() + rbac = RbacClass(dbApp, dbApp=dbApp) + accessible = catalog.getAccessibleDataObjects( + featureCode=inst.featureCode, + rbacInstance=rbac, + user=context.user, + mandateId=str(inst.mandateId) if inst.mandateId else "", + featureInstanceId=featureInstanceId, + ) + except Exception: + accessible = catalog.getDataObjects(inst.featureCode) + + accessibleKeys = {obj.get("objectKey", "") for obj in accessible} + + nodes: List[Dict[str, Any]] = [] + for obj in catalog.getDataObjects(inst.featureCode): + meta = obj.get("meta", {}) + if meta.get("wildcard") or meta.get("isGroup"): + continue + objectKey = obj.get("objectKey", "") + if objectKey not in accessibleKeys: + continue + tableName = meta.get("table", "") + if not tableName: + continue + triplet = _effectiveTripletFds(featureInstanceId, tableName, None, allFds) + rec = _findFdsRecord(allFds, featureInstanceId, tableName, None) + fields = meta.get("fields") if isinstance(meta, dict) else None + hasFields = bool(isinstance(fields, list) and len(fields) > 0) + # Surface the persisted per-field neutralize list so the UI can + # render & toggle field-level icons without an extra GET. + neutralizeFields: List[str] = [] + if rec and isinstance(rec.get("neutralizeFields"), list): + neutralizeFields = [f for f in rec["neutralizeFields"] if isinstance(f, str)] + nodes.append({ + "key": _encode("fdstbl", featureInstanceId, tableName), + "kind": "fdsTable", + "parentKey": parentKey, + "label": resolveTextSafe(obj.get("label", "")) or tableName, + "icon": "table", + # Children = the per-column field nodes. Only emitted when the + # data-object metadata declared a non-empty `fields` list. + "hasChildren": hasFields, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "FeatureDataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "featureInstanceId": featureInstanceId, + "featureCode": featureCode, + "tableName": tableName, + "objectKey": objectKey, + "neutralizeFields": neutralizeFields, + }) + return nodes + + +def _featureTableFields( + parentKey: str, + featureInstanceId: str, + tableName: str, + fieldNames: List[str], + allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Emit one node per declared column of a feature data table. + + Per-field neutralize semantics: + - The table-level FDS record carries `neutralizeFields: List[str]`. + - A field is "effectively neutralized" iff its name is in that list + OR the table's effective `neutralize` is True (blanket). + - Only `neutralize` is meaningful per-field; `scope` and `ragIndexEnabled` + are inherited from the parent table and not toggleable here. + """ + rec = _findFdsRecord(allFds, featureInstanceId, tableName, None) + tableNeutralize = bool(rec.get("neutralize")) if rec else False + neutralizeFields = rec.get("neutralizeFields") if rec else None + if not isinstance(neutralizeFields, list): + neutralizeFields = [] + + nodes: List[Dict[str, Any]] = [] + for field in fieldNames: + if not field: + continue + fieldNeutralized = bool(tableNeutralize or field in neutralizeFields) + nodes.append({ + "key": _encode("fdsfld", featureInstanceId, tableName, field), + "kind": "fdsField", + "parentKey": parentKey, + "label": field, + "icon": "field", + "hasChildren": False, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "FeatureDataSource" if rec else None, + "effectiveNeutralize": fieldNeutralized, + # Field-level scope/RAG do not exist as a concept; the FE hides + # those affordances when supportsRag=False. We still need + # `effectiveScope` + `effectiveRagIndexEnabled` for the + # contract; they reflect the parent's effective values so the + # backend stays single source of truth. + "effectiveScope": "personal", + "effectiveRagIndexEnabled": False, + "supportsRag": False, + "canBeAdded": rec is None, + "featureInstanceId": featureInstanceId, + "tableName": tableName, + "fieldName": field, + }) + return nodes + + +def resolveTextSafe(label: Any) -> str: + try: + from modules.shared.i18nRegistry import resolveText + return resolveText(label) + except Exception: + return str(label or "") + + +# --------------------------------------------------------------------------- +# Public entrypoint +# --------------------------------------------------------------------------- + +async def getChildrenForParents( + instanceId: str, + parents: List[Optional[str]], + context: Any, +) -> Dict[str, List[Dict[str, Any]]]: + """Return per-parent children lists. + + `parents` is a list with `None` representing the top-level. Order is preserved. + Returns a dict keyed by parent key (or '__root__' for None). + + Each child is a fully-rendered TreeNode dict (see module docstring for shape). + """ + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelDataSource import DataSource + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + + rootIf = getRootInterface() + + # Pre-load DS (per user) and FDS (per workspace) once for the whole request. + userId = str(context.user.id) + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"userId": userId}) or [] + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": instanceId}) or [] + + out: Dict[str, List[Dict[str, Any]]] = {} + + for parentKey in parents: + if parentKey is None: + try: + out["__root__"] = _topLevel(instanceId, context, rootIf, allDs, allFds) + except Exception as exc: + logger.exception("Tree top-level failed: %s", exc) + out["__root__"] = [] + continue + + try: + kind, parts = _decode(parentKey) + except Exception: + out[parentKey] = [] + continue + + try: + if parentKey == _KEY_PERSONAL_ROOT: + out[parentKey] = _personalRootChildren(instanceId, context, allDs) + + elif kind == "conn" and len(parts) == 1: + out[parentKey] = await _connectionServices(instanceId, context, parts[0], allDs) + + elif kind == "svc" and len(parts) == 2: + connId, service = parts + sourceType = _SERVICE_TO_SOURCE_TYPE.get(service, service) + out[parentKey] = await _browseChildren( + instanceId, context, connId, service, sourceType, "/", allDs, + parentKey=parentKey, + ) + + elif kind == "ds" and len(parts) == 3: + connId, sourceType, path = parts + # Determine service from sourceType (reverse map) + service = _reverseService(sourceType) + out[parentKey] = await _browseChildren( + instanceId, context, connId, service, sourceType, path, allDs, + parentKey=parentKey, + ) + + elif kind == "mgrp" and len(parts) == 1: + out[parentKey] = _featureConnectionsForMandate(instanceId, context, rootIf, parts[0], allFds) + + elif kind == "feat" and len(parts) == 3: + _mandateId, featureCode, featureInstanceId = parts + out[parentKey] = _featureTables(context, rootIf, parentKey, featureInstanceId, featureCode, allFds) + + elif kind == "fdstbl" and len(parts) == 2: + featureInstanceId, tableName = parts + fieldNames = _resolveTableFieldNames(featureInstanceId, tableName, rootIf) + out[parentKey] = _featureTableFields( + parentKey, featureInstanceId, tableName, fieldNames, allFds, + ) + + else: + out[parentKey] = [] + except Exception as exc: + logger.exception("Tree children for %s failed: %s", parentKey, exc) + out[parentKey] = [] + + return out + + +def _reverseService(sourceType: str) -> str: + for svc, st in _SERVICE_TO_SOURCE_TYPE.items(): + if st == sourceType: + return svc + return sourceType + + +def _resolveTableFieldNames(featureInstanceId: str, tableName: str, rootIf: Any) -> List[str]: + """Look up the declared column list for a (featureInstance, tableName) + pair via the RBAC catalog data-object metadata. Returns empty list when + the catalog has no entry (e.g. wildcard-only feature).""" + from modules.security.rbacCatalog import getCatalogService + inst = rootIf.getFeatureInstance(featureInstanceId) + if not inst: + return [] + catalog = getCatalogService() + for obj in catalog.getDataObjects(inst.featureCode) or []: + meta = obj.get("meta", {}) if isinstance(obj, dict) else {} + if meta.get("table") == tableName: + fields = meta.get("fields") + if isinstance(fields, list): + return [f for f in fields if isinstance(f, str) and f] + return [] + return [] + + +# --------------------------------------------------------------------------- +# Attribute-only refresh: given node keys, return current effective values +# --------------------------------------------------------------------------- + +async def getAttributesForKeys( + instanceId: str, + keys: List[str], + context: Any, +) -> Dict[str, Dict[str, Any]]: + """Return effective attribute values for a list of node keys. + + Used by the frontend after a toggle to refresh only attributes (neutralize, + scope, ragIndexEnabled) without reloading the tree structure. For container + nodes (personalRoot, mgrp), aggregates child values and returns 'mixed' + when children diverge.""" + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelDataSource import DataSource + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + + rootIf = getRootInterface() + userId = str(context.user.id) + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"userId": userId}) or [] + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": instanceId}) or [] + + result: Dict[str, Dict[str, Any]] = {} + + for key in keys: + try: + attrs = _resolveAttrsForKey(key, allDs, allFds, instanceId, context, rootIf) + if attrs is not None: + result[key] = attrs + if "mixed" in str(attrs.values()): + logger.info("getAttributesForKeys key=%s returned MIXED: %s", key, attrs) + except Exception as exc: + logger.warning("getAttributesForKeys failed for key=%s: %s", key, exc) + + logger.info("getAttributesForKeys: %d keys requested, %d resolved", len(keys), len(result)) + return result + + +def _resolveAttrsForKey( + key: str, + allDs: List[Dict[str, Any]], + allFds: List[Dict[str, Any]], + instanceId: str, + context: Any, + rootIf: Any, +) -> Optional[Dict[str, Any]]: + """Resolve effective attributes for a single node key.""" + if key == _KEY_PERSONAL_ROOT: + return _aggregatePersonalRoot(allDs) + + try: + kind, parts = _decode(key) + except Exception: + return None + + if kind == "mgrp" and len(parts) == 1: + return _aggregateMandateGroup(parts[0], allFds, instanceId, context, rootIf) + + if kind == "conn" and len(parts) == 1: + connId = parts[0] + return _aggregateConnection(connId, allDs) + + if kind == "svc" and len(parts) == 2: + connId, service = parts + sourceType = _SERVICE_TO_SOURCE_TYPE.get(service, service) + return _effectiveTripletDs(connId, sourceType, "/", allDs) + + if kind == "ds" and len(parts) == 3: + connId, sourceType, path = parts + return _effectiveTripletDs(connId, sourceType, path, allDs) + + if kind == "feat" and len(parts) == 3: + _mandateId, _featureCode, featureInstanceId = parts + return _effectiveTripletFds(featureInstanceId, "*", None, allFds) + + if kind == "fdsws" and len(parts) == 2: + workspaceInstanceId, _featureCode = parts + return _effectiveTripletFds(workspaceInstanceId, "*", None, allFds) + + if kind == "fdstbl" and len(parts) == 2: + featureInstanceId, tableName = parts + return _effectiveTripletFds(featureInstanceId, tableName, None, allFds) + + if kind == "fdsrec" and len(parts) == 3: + featureInstanceId, tableName, recordId = parts + return _effectiveTripletFds(featureInstanceId, tableName, {"objectKey": recordId}, allFds) + + if kind == "fdsfld" and len(parts) >= 3: + featureInstanceId, tableName = parts[0], parts[1] + fieldName = parts[2] if len(parts) > 2 else "" + parentFds = None + for fds in allFds: + if (fds.get("featureInstanceId") == featureInstanceId + and (fds.get("tableName") or "") == tableName + and fds.get("recordFilter") is None): + parentFds = fds + break + neutralizeFields = (parentFds.get("neutralizeFields") or []) if parentFds else [] + return {"effectiveNeutralize": fieldName in neutralizeFields} + + return None + + +def _aggregateConnection(connId: str, allDs: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate effective values for a connection node. + + If the connection has an authority-level DS record (path="/"), use the + standard aggregate mode on it (which already handles subtree correctly). + Otherwise compute effective values for each child DS using walk mode and + aggregate them manually.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + getEffectiveFlag, _AUTHORITY_SOURCE_TYPES, + ) + connRecords = [d for d in allDs if d.get("connectionId") == connId] + if not connRecords: + return {"effectiveNeutralize": False, "effectiveScope": "personal", "effectiveRagIndexEnabled": False} + + rootRec = None + for r in connRecords: + st = r.get("sourceType", "") + if st in _AUTHORITY_SOURCE_TYPES and _normalisePath(r.get("path", "")) == "/": + rootRec = r + break + + if rootRec: + return _effectiveTripletDs(connId, rootRec.get("sourceType", ""), "/", allDs) + + neutralizeVals = set() + scopeVals = set() + ragVals = set() + for r in connRecords: + neutralizeVals.add(getEffectiveFlag(r, "neutralize", allDs, mode="walk")) + scopeVals.add(getEffectiveFlag(r, "scope", allDs, mode="walk")) + ragVals.add(getEffectiveFlag(r, "ragIndexEnabled", allDs, mode="walk")) + return { + "effectiveNeutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False), + "effectiveScope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"), + "effectiveRagIndexEnabled": "mixed" if len(ragVals) > 1 else (ragVals.pop() if ragVals else False), + } + + +def _aggregatePersonalRoot(allDs: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate effective values across all personal DS records. + + Uses getEffectiveFlag in aggregate mode on each connection-root record. + If no root records exist, aggregates walk-effective values of all records.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + getEffectiveFlag, _AUTHORITY_SOURCE_TYPES, + ) + if not allDs: + return {"effectiveNeutralize": False, "effectiveScope": "personal", "effectiveRagIndexEnabled": False} + + rootRecords = [ + d for d in allDs + if d.get("sourceType", "") in _AUTHORITY_SOURCE_TYPES + and _normalisePath(d.get("path", "")) == "/" + ] + targets = rootRecords if rootRecords else allDs + + neutralizeVals = set() + scopeVals = set() + ragVals = set() + for ds in targets: + neutralizeVals.add(getEffectiveFlag(ds, "neutralize", allDs, mode="aggregate")) + scopeVals.add(getEffectiveFlag(ds, "scope", allDs, mode="aggregate")) + ragVals.add(getEffectiveFlag(ds, "ragIndexEnabled", allDs, mode="aggregate")) + return { + "effectiveNeutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False), + "effectiveScope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"), + "effectiveRagIndexEnabled": "mixed" if len(ragVals) > 1 else (ragVals.pop() if ragVals else False), + } + + +def _aggregateMandateGroup( + mandateId: str, + allFds: List[Dict[str, Any]], + instanceId: str, + context: Any, + rootIf: Any, +) -> Dict[str, Any]: + """Aggregate effective values across FDS records belonging to this mandate group. + + Uses getEffectiveFlagFds in aggregate mode on each workspace-level FDS + (tableName="*") that belongs to the given mandateId. This correctly resolves + inherited values from the full FDS hierarchy.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + + groupFds = [f for f in allFds if f.get("mandateId") == mandateId] + workspaceLevelFds = [f for f in groupFds if (f.get("tableName") or "") == "*"] + targets = workspaceLevelFds if workspaceLevelFds else groupFds + + if not targets: + return {"effectiveNeutralize": False, "effectiveScope": "personal", "effectiveRagIndexEnabled": False} + + neutralizeVals = set() + scopeVals = set() + ragVals = set() + for fds in targets: + neutralizeVals.add(getEffectiveFlagFds(fds, "neutralize", allFds, mode="aggregate")) + scopeVals.add(getEffectiveFlagFds(fds, "scope", allFds, mode="aggregate")) + ragVals.add(getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate")) + return { + "effectiveNeutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False), + "effectiveScope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"), + "effectiveRagIndexEnabled": "mixed" if len(ragVals) > 1 else (ragVals.pop() if ragVals else False), + } diff --git a/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py index 00180c9f..64a0019c 100644 --- a/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py +++ b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py @@ -3,9 +3,15 @@ """Cascade-inherit semantics for DataSource flags (neutralize, ragIndexEnabled, scope). Three-state flags allow tree elements to either set an explicit value or -inherit the value from their nearest ancestor in the path hierarchy. The -walker (RAG/Neutralize) and routes resolve the *effective* value; the cascade -helper resets explicit descendant values when a parent is toggled. +inherit the value from their nearest ancestor in the path hierarchy. + +Modes: + - 'walk' (default): resolves the *concrete* effective value per-item + (never returns 'mixed'). Used by backend consumers (RAG walker, + neutralization pipeline, scope filter, etc.). + - 'aggregate': resolves the *display* effective value per-item. If the + item has descendants with differing walk-effective values, returns + 'mixed'. Used by listing endpoints and PATCH responses for the UI. Path-traversal rules: - A DataSource is identified by `(connectionId, sourceType, path)`. @@ -17,11 +23,12 @@ Path-traversal rules: """ import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple logger = logging.getLogger(__name__) _INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope") +_INHERITABLE_FDS_FLAGS = ("neutralize", "ragIndexEnabled", "scope") # Connection-root DataSources carry the authority as their sourceType # (e.g. 'msft', 'google'). They sit one level above all service DataSources @@ -29,6 +36,12 @@ _INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope") # cross sourceType boundaries — but ONLY from these authority roots. _AUTHORITY_SOURCE_TYPES = frozenset({"local", "google", "msft", "clickup", "infomaniak"}) +Mode = Literal["walk", "aggregate"] + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- def _normalisePath(path: Optional[str]) -> str: """Normalize a DataSource path to '/'-prefixed, no trailing slash (except root).""" @@ -49,10 +62,7 @@ def _flagDefault(flag: str) -> Any: def _isExplicit(value: Any) -> bool: - """A flag value is explicit when it is not None. - - Note: legacy rows may carry empty-string scope; treat as inherit too. - """ + """A flag value is explicit when it is not None/empty-string.""" if value is None: return False if isinstance(value, str) and value == "": @@ -66,6 +76,21 @@ def _getRecordValue(rec: Any, key: str) -> Any: return getattr(rec, key, None) +def _isAncestorPath(ancestor: str, descendant: str) -> bool: + """True iff `ancestor` is a strict path-prefix of `descendant`.""" + if ancestor == descendant: + return False + if ancestor == "/": + return descendant != "/" + return descendant.startswith(ancestor + "/") + + +def _pathDepth(path: str) -> int: + if path == "/": + return 0 + return path.count("/") + + def _findAncestorChain( rec: Dict[str, Any], allDs: Iterable[Dict[str, Any]], @@ -74,15 +99,13 @@ def _findAncestorChain( ordered nearest-first. Two ancestor relations are merged: - 1) **same-sourceType path-ancestor** — strict path-prefix within the - same service tree (sharepointFolder, gmailFolder, ...). - 2) **connection-root ancestor** — a DS with `path='/'` and - `sourceType` ∈ authority set (msft, google, ...) is the parent of - every other DS in that connection regardless of sourceType, so a - toggle on the connection node propagates to all services beneath. + 1) same-sourceType path-ancestor — strict path-prefix within the + same service tree. + 2) connection-root ancestor — a DS with `path='/'` and + `sourceType` in authority set is the parent of every other DS + in that connection regardless of sourceType. - The connection-root is always the most distant ancestor and therefore - sorts after any same-sourceType ancestors. + The connection-root is always the most distant ancestor. """ recPath = _normalisePath(_getRecordValue(rec, "path")) recSourceType = _getRecordValue(rec, "sourceType") @@ -114,36 +137,89 @@ def _findAncestorChain( return chain -def _isAncestorPath(ancestor: str, descendant: str) -> bool: - """True iff `ancestor` is a strict path-prefix of `descendant`. +def _isDescendantDs(parentRec: Dict[str, Any], candidate: Dict[str, Any]) -> bool: + """True iff `candidate` is a descendant of `parentRec` in the DS hierarchy.""" + parentSourceType = _getRecordValue(parentRec, "sourceType") + parentPath = _normalisePath(_getRecordValue(parentRec, "path")) + parentConnectionId = _getRecordValue(parentRec, "connectionId") + parentId = _getRecordValue(parentRec, "id") - '/' is ancestor of every non-root path. For non-root prefixes, the - descendant must continue with '/' so '/foo' isn't treated as ancestor of - '/foobar'. - """ - if ancestor == descendant: + candId = _getRecordValue(candidate, "id") + if candId == parentId: + return False + if _getRecordValue(candidate, "connectionId") != parentConnectionId: return False - if ancestor == "/": - return descendant != "/" - return descendant.startswith(ancestor + "/") + candSourceType = _getRecordValue(candidate, "sourceType") + candPath = _normalisePath(_getRecordValue(candidate, "path")) + + parentIsConnectionRoot = ( + parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/" + ) + if parentIsConnectionRoot: + return True + if candSourceType != parentSourceType: + return False + return _isAncestorPath(parentPath, candPath) + + +# --------------------------------------------------------------------------- +# DataSource: getEffectiveFlag +# --------------------------------------------------------------------------- def getEffectiveFlag( rec: Dict[str, Any], flag: str, sameConnectionDs: Iterable[Dict[str, Any]], + mode: Mode = "walk", ) -> Any: """Resolve the effective value of a flag via path-traversal. - Order: own value (if explicit) → nearest ancestor with explicit value → - static default (`False` or `'personal'`). + mode='walk': own explicit → nearest ancestor explicit → default. + Always returns a concrete value (never 'mixed'). + mode='aggregate': same as walk for leaf value, but if the item has + descendants whose walk-effective values differ from + each other, returns 'mixed'. """ if flag not in _INHERITABLE_FLAGS: raise ValueError(f"Unknown inheritable flag: {flag}") + + allDs = list(sameConnectionDs) + + walkValue = _resolveWalkValue(rec, flag, allDs) + + if mode == "walk": + return walkValue + + # mode == 'aggregate': check subtree for heterogeneous effective values + descendants = [d for d in allDs if _isDescendantDs(rec, d)] + if not descendants: + return walkValue + + subtreeValues = set() + subtreeValues.add(_normaliseForComparison(walkValue)) + for desc in descendants: + descEffective = _resolveWalkValue(desc, flag, allDs) + subtreeValues.add(_normaliseForComparison(descEffective)) + if len(subtreeValues) > 1: + recId = _getRecordValue(rec, "id") + descId = _getRecordValue(desc, "id") + descOwnVal = _getRecordValue(desc, flag) + logger.info( + "DS aggregate MIXED for rec=%s flag=%s: walkValue=%s, " + "divergent desc=%s (own=%s, effective=%s), subtreeValues=%s", + recId, flag, walkValue, descId, descOwnVal, descEffective, subtreeValues, + ) + return "mixed" + return walkValue + + +def _resolveWalkValue(rec: Dict[str, Any], flag: str, allDs: List[Dict[str, Any]]) -> Any: + """Core walk resolution: own explicit → ancestor chain → default.""" own = _getRecordValue(rec, flag) if _isExplicit(own): return own - chain = _findAncestorChain(rec, sameConnectionDs) + chain = _findAncestorChain(rec, allDs) for ancestor in chain: ancestorVal = _getRecordValue(ancestor, flag) if _isExplicit(ancestorVal): @@ -151,69 +227,112 @@ def getEffectiveFlag( return _flagDefault(flag) +def _normaliseForComparison(value: Any) -> Any: + """Normalize values for set-comparison (bool as int to avoid hash issues).""" + if isinstance(value, bool): + return int(value) + return value + + +# --------------------------------------------------------------------------- +# DataSource: cascadeResetDescendants (bottom-up) +# --------------------------------------------------------------------------- + def cascadeResetDescendants( rootIf: Any, parentRec: Dict[str, Any], flag: str, -) -> int: +) -> List[str]: """Reset all explicit descendant values of `flag` to NULL (= inherit). - Descendant relation mirrors `_findAncestorChain`: - - Connection-root (`path='/'` AND `sourceType` ∈ authorities) is parent - of every other DS in that connection (cross-sourceType cascade). - - Otherwise: same-sourceType strict path-descendants only. + Reset order: bottom-up (deepest first) for crash safety. + The parent itself is NOT modified here — the caller sets the master value + after this function returns. - Only the targeted `flag` is reset; other flags on the descendant are - untouched. - - Returns the number of records updated. + Returns list of reset record IDs in bottom-up order. """ if flag not in _INHERITABLE_FLAGS: raise ValueError(f"Unknown inheritable flag: {flag}") from modules.datamodels.datamodelDataSource import DataSource connectionId = _getRecordValue(parentRec, "connectionId") - parentSourceType = _getRecordValue(parentRec, "sourceType") - parentPath = _normalisePath(_getRecordValue(parentRec, "path")) parentId = _getRecordValue(parentRec, "id") - if not connectionId or not parentSourceType: - return 0 - - parentIsConnectionRoot = ( - parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/" - ) + if not connectionId: + return [] siblings = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) - affected = 0 + + toReset: List[Tuple[int, str]] = [] for sib in siblings: - sibId = _getRecordValue(sib, "id") - if sibId == parentId: + if not _isDescendantDs(parentRec, sib): continue - sibSourceType = _getRecordValue(sib, "sourceType") - sibPath = _normalisePath(_getRecordValue(sib, "path")) - if parentIsConnectionRoot: - # Connection-root resets everything else under this connection. - pass - else: - if sibSourceType != parentSourceType: - continue - if not _isAncestorPath(parentPath, sibPath): - continue sibVal = _getRecordValue(sib, flag) if not _isExplicit(sibVal): continue + sibId = _getRecordValue(sib, "id") + sibPath = _normalisePath(_getRecordValue(sib, "path")) + toReset.append((_pathDepth(sibPath), sibId)) + + # Sort deepest first (bottom-up) + toReset.sort(key=lambda x: x[0], reverse=True) + + resetIds: List[str] = [] + for _, sibId in toReset: try: rootIf.db.recordModify(DataSource, sibId, {flag: None}) - affected += 1 + resetIds.append(sibId) except Exception as exc: logger.warning("Cascade-reset failed for DataSource %s flag=%s: %s", sibId, flag, exc) - if affected: - logger.info( - "Cascade-reset %s on %d descendants of DataSource (connectionId=%s, sourceType=%s, path=%s, connectionRoot=%s)", - flag, affected, connectionId, parentSourceType, parentPath, parentIsConnectionRoot, - ) - return affected + if resetIds: + logger.info( + "Cascade-reset %s on %d descendants of DataSource %s (bottom-up)", + flag, len(resetIds), parentId, + ) + return resetIds + + +# --------------------------------------------------------------------------- +# DataSource: collectAncestorChain (for updatedAncestors in PATCH response) +# --------------------------------------------------------------------------- + +def collectAncestorChain( + rec: Dict[str, Any], + sameConnectionDs: Iterable[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return ancestor chain of `rec` (nearest-first), same as internal helper. + + Exposed for PATCH endpoints to compute updatedAncestors. + """ + return _findAncestorChain(rec, sameConnectionDs) + + +# --------------------------------------------------------------------------- +# DataSource: buildEffectiveByConnection +# --------------------------------------------------------------------------- + +def buildEffectiveByConnection( + dataSources: Iterable[Dict[str, Any]], + flag: str, + mode: Mode = "walk", +) -> Dict[str, Any]: + """Pre-compute the effective value of `flag` for every DataSource id. + + Uses the specified mode. O(N^2) worst case but N is bounded per connection. + """ + if flag not in _INHERITABLE_FLAGS: + raise ValueError(f"Unknown inheritable flag: {flag}") + allDs = list(dataSources) + out: Dict[str, Any] = {} + for rec in allDs: + recId = _getRecordValue(rec, "id") + out[recId] = getEffectiveFlag(rec, flag, allDs, mode=mode) + return out + + +# --------------------------------------------------------------------------- +# FeatureDataSource helpers +# --------------------------------------------------------------------------- def _fdsClassify(fds: Dict[str, Any]) -> str: """Return 'workspace' | 'table' | 'record' based on the FDS identifier shape.""" @@ -229,14 +348,14 @@ def _fdsClassify(fds: Dict[str, Any]) -> str: def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool: """Return True iff `parent` FDS is a strict ancestor of `child` FDS. - Hierarchy within one `workspaceInstanceId`: - workspace-wildcard (tableName='*') → table-wildcard (tableName='X', !recordFilter) - → record-fds (tableName='X', recordFilter.id=...) - table-wildcard (tableName='X') → record-fds (tableName='X', recordFilter.id=...) + Hierarchy within one featureInstanceId (allFds is already scoped to + a single workspace): + feature-wildcard (tableName='*') -> table-wildcard / record-fds + table-wildcard (tableName='X') -> record-fds (tableName='X') """ - parentWsId = _getRecordValue(parent, "workspaceInstanceId") - childWsId = _getRecordValue(child, "workspaceInstanceId") - if not parentWsId or parentWsId != childWsId: + parentFiId = _getRecordValue(parent, "featureInstanceId") + childFiId = _getRecordValue(child, "featureInstanceId") + if not parentFiId or parentFiId != childFiId: return False if _getRecordValue(parent, "id") == _getRecordValue(child, "id"): return False @@ -251,23 +370,68 @@ def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool: return False +def _fdsDepth(fds: Dict[str, Any]) -> int: + kind = _fdsClassify(fds) + if kind == "workspace": + return 0 + if kind == "table": + return 1 + return 2 + + +# --------------------------------------------------------------------------- +# FeatureDataSource: getEffectiveFlagFds +# --------------------------------------------------------------------------- + def getEffectiveFlagFds( rec: Dict[str, Any], flag: str, sameWorkspaceFds: Iterable[Dict[str, Any]], + mode: Mode = "walk", ) -> Any: """Resolve effective value of a FeatureDataSource flag. - Order: own (if explicit) → table-wildcard (if explicit) → - workspace-wildcard (if explicit) → static default. + mode='walk': own explicit -> table-wildcard -> workspace-wildcard -> default. + mode='aggregate': same but returns 'mixed' if descendants diverge. """ - if flag not in ("neutralize", "scope"): + if flag not in _INHERITABLE_FDS_FLAGS: raise ValueError(f"Unknown inheritable FDS flag: {flag}") + + allFds = list(sameWorkspaceFds) + walkValue = _resolveWalkValueFds(rec, flag, allFds) + + if mode == "walk": + return walkValue + + # mode == 'aggregate' + descendants = [f for f in allFds if _fdsIsAncestor(rec, f)] + if not descendants: + return walkValue + + subtreeValues = set() + subtreeValues.add(_normaliseForComparison(walkValue)) + for desc in descendants: + descEffective = _resolveWalkValueFds(desc, flag, allFds) + subtreeValues.add(_normaliseForComparison(descEffective)) + if len(subtreeValues) > 1: + recId = _getRecordValue(rec, "id") + descId = _getRecordValue(desc, "id") + descOwnVal = _getRecordValue(desc, flag) + logger.info( + "FDS aggregate MIXED for rec=%s flag=%s: walkValue=%s, " + "divergent desc=%s (own=%s, effective=%s), subtreeValues=%s", + recId, flag, walkValue, descId, descOwnVal, descEffective, subtreeValues, + ) + return "mixed" + return walkValue + + +def _resolveWalkValueFds(rec: Dict[str, Any], flag: str, allFds: List[Dict[str, Any]]) -> Any: + """Core walk resolution for FDS.""" own = _getRecordValue(rec, flag) if _isExplicit(own): return own - workspaceFds: List[Dict[str, Any]] = list(sameWorkspaceFds) - ancestors = [a for a in workspaceFds if _fdsIsAncestor(a, rec)] + ancestors = [a for a in allFds if _fdsIsAncestor(a, rec)] ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1) for ancestor in ancestors: val = _getRecordValue(ancestor, flag) @@ -276,27 +440,32 @@ def getEffectiveFlagFds( return _flagDefault(flag) +# --------------------------------------------------------------------------- +# FeatureDataSource: cascadeResetDescendantsFds (bottom-up) +# --------------------------------------------------------------------------- + def cascadeResetDescendantsFds( rootIf: Any, parentRec: Dict[str, Any], flag: str, -) -> int: +) -> List[str]: """Reset explicit `flag` to NULL on every descendant FDS of `parentRec`. - Only the targeted flag is reset; other flags on descendants are untouched. - Returns the number of records updated. + Reset order: bottom-up (deepest first) for crash safety. + Returns list of reset record IDs in bottom-up order. """ - if flag not in ("neutralize", "scope"): + if flag not in _INHERITABLE_FDS_FLAGS: raise ValueError(f"Unknown inheritable FDS flag: {flag}") from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource workspaceInstanceId = _getRecordValue(parentRec, "workspaceInstanceId") if not workspaceInstanceId: - return 0 + return [] siblings = rootIf.db.getRecordset( FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId} ) - affected = 0 + + toReset: List[Tuple[int, str]] = [] for sib in siblings: if not _fdsIsAncestor(parentRec, sib): continue @@ -304,39 +473,159 @@ def cascadeResetDescendantsFds( if not _isExplicit(sibVal): continue sibId = _getRecordValue(sib, "id") + toReset.append((_fdsDepth(sib), sibId)) + + # Sort deepest first (bottom-up) + toReset.sort(key=lambda x: x[0], reverse=True) + + resetIds: List[str] = [] + for _, sibId in toReset: try: rootIf.db.recordModify(FeatureDataSource, sibId, {flag: None}) - affected += 1 + resetIds.append(sibId) except Exception as exc: logger.warning("FDS cascade-reset failed for %s flag=%s: %s", sibId, flag, exc) - if affected: + + if resetIds: logger.info( - "FDS cascade-reset %s on %d descendants of FDS (workspaceInstanceId=%s, kind=%s)", - flag, affected, workspaceInstanceId, _fdsClassify(parentRec), + "FDS cascade-reset %s on %d descendants of FDS %s (bottom-up)", + flag, len(resetIds), _getRecordValue(parentRec, "id"), ) - return affected + return resetIds -def buildEffectiveByConnection( - dataSources: Iterable[Dict[str, Any]], - flag: str, -) -> Dict[str, Any]: - """Pre-compute the effective value of `flag` for every DataSource id. +# --------------------------------------------------------------------------- +# FeatureDataSource: collectAncestorChainFds +# --------------------------------------------------------------------------- - Useful for batch operations (walker, route DTOs) that touch many records - at once. O(N²) in the worst case but N is bounded per connection. +def collectAncestorChainFds( + rec: Dict[str, Any], + sameWorkspaceFds: Iterable[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return ancestor chain of `rec` FDS (nearest-first). + + Exposed for PATCH endpoints to compute updatedAncestors. """ - if flag not in _INHERITABLE_FLAGS: - raise ValueError(f"Unknown inheritable flag: {flag}") - bySourceType: Dict[Tuple[str, str], List[Dict[str, Any]]] = {} - for ds in dataSources: - connId = _getRecordValue(ds, "connectionId") or "" - srcType = _getRecordValue(ds, "sourceType") or "" - bySourceType.setdefault((connId, srcType), []).append(ds) + allFds = list(sameWorkspaceFds) + ancestors = [a for a in allFds if _fdsIsAncestor(a, rec)] + ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1) + return ancestors + +# --------------------------------------------------------------------------- +# FeatureDataSource: buildEffectiveByWorkspaceFds +# --------------------------------------------------------------------------- + +def buildEffectiveByWorkspaceFds( + fdses: Iterable[Dict[str, Any]], + flag: str, + mode: Mode = "walk", +) -> Dict[str, Any]: + """Pre-compute the effective value of `flag` for every FDS id.""" + if flag not in _INHERITABLE_FDS_FLAGS: + raise ValueError(f"Unknown inheritable FDS flag: {flag}") + allFds = list(fdses) out: Dict[str, Any] = {} - for group in bySourceType.values(): - for rec in group: - recId = _getRecordValue(rec, "id") - out[recId] = getEffectiveFlag(rec, flag, group) + for rec in allFds: + recId = _getRecordValue(rec, "id") + out[recId] = getEffectiveFlagFds(rec, flag, allFds, mode=mode) return out + + +# --------------------------------------------------------------------------- +# Bulk resolve: effective flags for arbitrary paths (even without DB record) +# --------------------------------------------------------------------------- + +def resolveEffectiveForPath( + connectionId: str, + sourceType: str, + path: str, + allDs: List[Dict[str, Any]], + mode: Mode = "aggregate", +) -> Dict[str, Any]: + """Resolve effective flags for ANY (connectionId, sourceType, path) tuple. + + Works whether or not a DataSource record exists for this exact path. + Returns dict with effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled. + """ + normPath = _normalisePath(path) + exactRecord = None + for ds in allDs: + if ( + _getRecordValue(ds, "connectionId") == connectionId + and _getRecordValue(ds, "sourceType") == sourceType + and _normalisePath(_getRecordValue(ds, "path")) == normPath + ): + exactRecord = ds + break + + if exactRecord: + return { + "effectiveNeutralize": getEffectiveFlag(exactRecord, "neutralize", allDs, mode=mode), + "effectiveScope": getEffectiveFlag(exactRecord, "scope", allDs, mode=mode), + "effectiveRagIndexEnabled": getEffectiveFlag(exactRecord, "ragIndexEnabled", allDs, mode=mode), + } + + virtualRec = { + "id": "__virtual__", + "connectionId": connectionId, + "sourceType": sourceType, + "path": normPath, + "neutralize": None, + "scope": None, + "ragIndexEnabled": None, + } + return { + "effectiveNeutralize": _resolveWalkValue(virtualRec, "neutralize", allDs), + "effectiveScope": _resolveWalkValue(virtualRec, "scope", allDs), + "effectiveRagIndexEnabled": _resolveWalkValue(virtualRec, "ragIndexEnabled", allDs), + } + + +def resolveEffectiveForFds( + featureInstanceId: str, + tableName: str, + recordFilter: Optional[Dict[str, str]], + allFds: List[Dict[str, Any]], + mode: Mode = "aggregate", +) -> Dict[str, Any]: + """Resolve effective flags for ANY FDS tuple (even without DB record). + + `allFds` is pre-scoped to a single workspace (loaded with + workspaceInstanceId filter). Within that set, the coordinate is + featureInstanceId + tableName + recordFilter. + + Returns dict with effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled. + """ + exactRecord = None + for fds in allFds: + if _getRecordValue(fds, "featureInstanceId") != featureInstanceId: + continue + if (_getRecordValue(fds, "tableName") or "") != tableName: + continue + fdsFilter = _getRecordValue(fds, "recordFilter") + if fdsFilter == recordFilter: + exactRecord = fds + break + + if exactRecord: + return { + "effectiveNeutralize": getEffectiveFlagFds(exactRecord, "neutralize", allFds, mode=mode), + "effectiveScope": getEffectiveFlagFds(exactRecord, "scope", allFds, mode=mode), + "effectiveRagIndexEnabled": getEffectiveFlagFds(exactRecord, "ragIndexEnabled", allFds, mode=mode), + } + + virtualRec = { + "id": "__virtual__", + "featureInstanceId": featureInstanceId, + "tableName": tableName, + "recordFilter": recordFilter, + "neutralize": None, + "scope": None, + "ragIndexEnabled": None, + } + return { + "effectiveNeutralize": _resolveWalkValueFds(virtualRec, "neutralize", allFds), + "effectiveScope": _resolveWalkValueFds(virtualRec, "scope", allFds), + "effectiveRagIndexEnabled": _resolveWalkValueFds(virtualRec, "ragIndexEnabled", allFds), + } diff --git a/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py b/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py index 6698e164..01c585d8 100644 --- a/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py +++ b/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py @@ -147,7 +147,7 @@ class KnowledgeService: else getattr(existing, "status", "") ) or "" if existingMeta.get("hash") == contentHash and existingStatus == "indexed": - logger.info( + logger.debug( "ingestion.skipped.duplicate sourceKind=%s sourceId=%s hash=%s", job.sourceKind, job.sourceId, contentHash[:12], extra={ diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py index 618a9965..be059eef 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py @@ -431,6 +431,15 @@ def registerKnowledgeIngestionConsumer() -> None: callbackRegistry.register("connection.established", _onConnectionEstablished) callbackRegistry.register("connection.revoked", _onConnectionRevoked) registerJobHandler(BOOTSTRAP_JOB_TYPE, _bootstrapJobHandler) + + from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import ( + FEATURE_BOOTSTRAP_JOB_TYPE, _featureBootstrapHandler, + ) + registerJobHandler(FEATURE_BOOTSTRAP_JOB_TYPE, _featureBootstrapHandler) + registerDailyResyncScheduler() _registered = True - logger.info("KnowledgeIngestionConsumer registered (established/revoked + %s handler + daily resync)", BOOTSTRAP_JOB_TYPE) + logger.info( + "KnowledgeIngestionConsumer registered (established/revoked + %s + %s handler + daily resync)", + BOOTSTRAP_JOB_TYPE, FEATURE_BOOTSTRAP_JOB_TYPE, + ) diff --git a/modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py b/modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py new file mode 100644 index 00000000..aa81d929 --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py @@ -0,0 +1,289 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Feature-data RAG bootstrap: indexes FeatureDataSource rows into the knowledge store. + +Analogous to connection.bootstrap for external connections (Google, Microsoft), +this handler reads FeatureDataSource records with ragIndexEnabled=True, queries +the underlying feature tables via FeatureDataProvider, serialises each row into +text, and feeds it through KnowledgeService.requestIngestion so the data +appears in ContentChunk embeddings for semantic RAG search. + +Job type: ``feature.bootstrap`` +Payload: ``{"workspaceInstanceId": "...", "featureDataSourceIds": [...] (optional)}`` +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +FEATURE_BOOTSTRAP_JOB_TYPE = "feature.bootstrap" + + +def _loadRagEnabledFds(workspaceInstanceId: str, featureDataSourceIds: Optional[List[str]] = None): + """Load FeatureDataSource rows whose effective ragIndexEnabled is True. + + Returns dicts with resolved flags so downstream code can read them directly. + """ + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + + rootIf = getRootInterface() + allFds = rootIf.db.getRecordset( + FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId} + ) + resolved = [] + for fds in allFds: + tblName = (fds.get("tableName") if isinstance(fds, dict) else getattr(fds, "tableName", "")) or "" + fCode = (fds.get("featureCode") if isinstance(fds, dict) else getattr(fds, "featureCode", "")) or "" + if tblName == "*" or not tblName or not fCode: + continue + effRag = getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate") + if effRag is not True: + continue + row = dict(fds) if isinstance(fds, dict) else {**fds.__dict__} + row["_effectiveNeutralize"] = getEffectiveFlagFds(fds, "neutralize", allFds, mode="aggregate") + row["_effectiveScope"] = getEffectiveFlagFds(fds, "scope", allFds, mode="aggregate") or "featureInstance" + row["ragIndexEnabled"] = True + resolved.append(row) + + if featureDataSourceIds: + idSet = set(featureDataSourceIds) + resolved = [r for r in resolved if r.get("id") in idSet] + return resolved + + +def _serializeRowToText(row: Dict[str, Any], neutralizeFields: Optional[List[str]] = None) -> str: + """Convert a feature-table row into readable text for embedding. + + Skips internal fields (starting with '_' or 'sys') and produces + ``key: value`` lines that embed well semantically. + """ + neutralizeSet = set(neutralizeFields or []) + lines = [] + for key, value in row.items(): + if key.startswith("_") or key.startswith("sys"): + continue + if key == "id": + continue + if value is None or value == "" or value == []: + continue + if key in neutralizeSet: + value = "[REDACTED]" + elif isinstance(value, (dict, list)): + value = json.dumps(value, ensure_ascii=False, default=str) + else: + value = str(value) + lines.append(f"{key}: {value}") + return "\n".join(lines) + + +def _getFeatureDbConnector(featureCode: str): + """Create a lightweight DB connector to the feature database.""" + from modules.connectors.connectorDbPostgre import DatabaseConnector + from modules.shared.configuration import APP_CONFIG + + dbName = f"poweron_{featureCode.lower()}" + return DatabaseConnector( + dbHost=APP_CONFIG.get("DB_HOST", "localhost"), + dbDatabase=dbName, + dbUser=APP_CONFIG.get("DB_USER"), + dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), + dbPort=int(APP_CONFIG.get("DB_PORT", 5432)), + userId="system.feature_bootstrap", + ) + + +async def _featureBootstrapHandler( + job: Dict[str, Any], + progressCb, +) -> Dict[str, Any]: + """Walk RAG-enabled FeatureDataSources and index their rows.""" + payload = job.get("payload") or {} + workspaceInstanceId = payload.get("workspaceInstanceId") + featureDataSourceIds = payload.get("featureDataSourceIds") + if not workspaceInstanceId: + raise ValueError("feature.bootstrap requires payload.workspaceInstanceId") + + progressCb(5, messageKey="Feature-Datenquellen werden geladen...") + + fdsList = _loadRagEnabledFds(workspaceInstanceId, featureDataSourceIds) + if not fdsList: + logger.info( + "feature.bootstrap.skipped — no rag-enabled FDS for workspace %s", + workspaceInstanceId, + ) + return {"workspaceInstanceId": workspaceInstanceId, "skipped": True, "reason": "no_rag_enabled_fds"} + + from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider + from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob + from modules.serviceCenter.context import ServiceCenterContext + from modules.serviceCenter import getService + from modules.security.rootAccess import getRootUser + + totalIndexed = 0 + totalSkipped = 0 + totalFailed = 0 + fdsResults = [] + + for fdsIdx, fds in enumerate(fdsList): + fdsId = fds.get("id", "") + featureCode = fds.get("featureCode", "") + tableName = fds.get("tableName", "") + featureInstanceId = fds.get("featureInstanceId", "") + mandateId = fds.get("mandateId", "") + neutralizeFields = fds.get("neutralizeFields") or [] + recordFilter = fds.get("recordFilter") or {} + effectiveScope = fds.get("_effectiveScope", "featureInstance") + effectiveNeutralize = bool(fds.get("_effectiveNeutralize", False)) + + progressPct = 5 + int(90 * fdsIdx / len(fdsList)) + progressCb( + progressPct, + messageKey="Indexiere {table} ({n}/{total})...", + messageParams={"table": tableName, "n": fdsIdx + 1, "total": len(fdsList)}, + ) + + if not featureCode or not tableName or not featureInstanceId: + logger.warning("feature.bootstrap: skipping FDS %s — missing featureCode/tableName/fiId", fdsId) + continue + + try: + dbConnector = _getFeatureDbConnector(featureCode) + provider = FeatureDataProvider(dbConnector) + + rootUser = getRootUser() + ctx = ServiceCenterContext( + user=rootUser, + mandate_id=mandateId, + feature_instance_id=workspaceInstanceId, + ) + knowledgeService = getService("knowledge", ctx) + + extraFilters = [ + {"field": k, "op": "=", "value": v} + for k, v in recordFilter.items() + ] if recordFilter else None + + batchSize = 200 + offset = 0 + fdsIndexed = 0 + fdsSkipped = 0 + fdsFailed = 0 + + while True: + result = provider.browseTable( + tableName=tableName, + featureInstanceId=featureInstanceId, + mandateId=mandateId, + limit=batchSize, + offset=offset, + extraFilters=extraFilters, + ) + rows = result.get("rows", []) + if not rows: + break + + for row in rows: + rowId = row.get("id", "") + if not rowId: + continue + + textContent = _serializeRowToText(row, neutralizeFields if effectiveNeutralize else None) + if not textContent.strip(): + fdsSkipped += 1 + continue + + contentVersion = str(row.get("sysUpdatedAt") or row.get("sysCreatedAt") or "") + + ingestionJob = IngestionJob( + sourceKind="feature_record", + sourceId=f"{workspaceInstanceId}:{tableName}:{rowId}", + fileName=f"{tableName}-{rowId}", + mimeType="application/vnd.poweron.feature-record+json", + userId=fds.get("userId") or "system", + featureInstanceId=workspaceInstanceId, + mandateId=mandateId, + contentObjects=[{ + "contentType": "text", + "data": textContent, + "contextRef": { + "table": tableName, + "featureCode": featureCode, + "featureInstanceId": featureInstanceId, + "rowId": rowId, + }, + "contentObjectId": f"{tableName}:{rowId}", + }], + structure={"sourceTable": tableName, "featureCode": featureCode}, + contentVersion=contentVersion, + provenance={ + "featureDataSourceId": fdsId, + "tableName": tableName, + "featureCode": featureCode, + "featureInstanceId": featureInstanceId, + }, + neutralize=effectiveNeutralize, + ) + + try: + handle = await knowledgeService.requestIngestion(ingestionJob) + if handle.status == "failed": + fdsFailed += 1 + logger.warning( + "feature.bootstrap: ingestion failed fds=%s table=%s row=%s error=%s", + fdsId, tableName, rowId, handle.error, + ) + elif handle.status == "duplicate": + fdsSkipped += 1 + else: + fdsIndexed += 1 + except Exception as ingErr: + fdsFailed += 1 + logger.error( + "feature.bootstrap: ingestion error fds=%s row=%s: %s", + fdsId, rowId, ingErr, + ) + + offset += batchSize + if len(rows) < batchSize: + break + + totalIndexed += fdsIndexed + totalSkipped += fdsSkipped + totalFailed += fdsFailed + + fdsResults.append({ + "featureDataSourceId": fdsId, + "tableName": tableName, + "featureCode": featureCode, + "indexed": fdsIndexed, + "skippedDuplicate": fdsSkipped, + "failed": fdsFailed, + }) + + except Exception as fdsErr: + logger.error( + "feature.bootstrap: error processing FDS %s (%s.%s): %s", + fdsId, featureCode, tableName, fdsErr, exc_info=True, + ) + fdsResults.append({ + "featureDataSourceId": fdsId, + "tableName": tableName, + "featureCode": featureCode, + "error": str(fdsErr), + }) + + progressCb(100, messageKey="Feature-Daten-Sync abgeschlossen.") + + return { + "workspaceInstanceId": workspaceInstanceId, + "indexed": totalIndexed, + "skippedDuplicate": totalSkipped, + "failed": totalFailed, + "dataSources": fdsResults, + } diff --git a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py b/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py deleted file mode 100644 index 0deae777..00000000 --- a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2025 Patrick Motsch -# All rights reserved. -"""DEPRECATED: Use `_inheritFlags.getEffectiveFlag()` directly. - -Thin shim to the new cascade-inherit helper. Kept so external callers don't -break on import — internal walkers consume pre-resolved dicts via -`_loadRagEnabledDataSources`. -""" - -from __future__ import annotations - -from typing import Any, Dict, List - -from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag - - -def resolveEffectiveNeutralize( - ds: Dict[str, Any], - allDataSources: List[Dict[str, Any]], -) -> bool: - """DEPRECATED: use `getEffectiveFlag(ds, 'neutralize', allDataSources)`.""" - value = getEffectiveFlag(ds, "neutralize", allDataSources) - return bool(value) - - -def resolveEffectiveRagIndexEnabled( - ds: Dict[str, Any], - allDataSources: List[Dict[str, Any]], -) -> bool: - """DEPRECATED: use `getEffectiveFlag(ds, 'ragIndexEnabled', allDataSources)`.""" - value = getEffectiveFlag(ds, "ragIndexEnabled", allDataSources) - return bool(value) diff --git a/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py b/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py index 8e65fd0f..41d9d458 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py +++ b/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py @@ -15,8 +15,9 @@ up with "Job stuck at 10% for 10h" zombies. These helpers wrap each phase in `asyncio.wait_for`. Sync extraction runs on a worker thread so the loop stays responsive. Every wrapped call also -emits a short start/done log line, so when something hangs we know the -exact item that caused it (path, size, mime). +emits start/done log lines at DEBUG so normal INFO logs stay quiet; for +stuck-job triage, enable DEBUG for this module — the last +``walker.item.start`` before a hang still pinpoints the item (path, size, mime). """ from __future__ import annotations @@ -48,7 +49,7 @@ async def downloadWithTimeout( used in log messages so we can pinpoint the offending item in case of a hang or timeout. """ - logger.info("walker.download.start %s timeout=%ds", label, timeoutSeconds) + logger.debug("walker.download.start %s timeout=%ds", label, timeoutSeconds) try: result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds) logger.debug("walker.download.done %s", label) @@ -71,7 +72,7 @@ async def extractWithTimeout( keep running until the process exits — but at least the walker proceeds to the next item instead of freezing forever. """ - logger.info("walker.extract.start %s timeout=%ds", label, timeoutSeconds) + logger.debug("walker.extract.start %s timeout=%ds", label, timeoutSeconds) try: result = await asyncio.wait_for( asyncio.to_thread(syncFn, *args), @@ -102,15 +103,15 @@ async def ingestWithTimeout( def logItemStart(service: str, label: str, *, sizeBytes: Optional[int] = None, mime: Optional[str] = None) -> None: - """Log that processing of one item is about to begin. + """Log that processing of one item is about to begin (DEBUG). When the worker hangs, the LAST `walker.item.start` line in the log - points to the exact item that caused the freeze. This is the single - most valuable diagnostic for stuck-job triage. + points to the exact item that caused the freeze. Enable DEBUG for this + module during triage. """ parts = [f"walker.item.start service={service} path={label}"] if sizeBytes is not None: parts.append(f"size={sizeBytes}") if mime: parts.append(f"mime={mime}") - logger.info(" ".join(parts)) + logger.debug(" ".join(parts)) diff --git a/scripts/script_migrate_user_uid.py b/scripts/script_migrate_user_uid.py new file mode 100644 index 00000000..07f9b443 --- /dev/null +++ b/scripts/script_migrate_user_uid.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +"""One-time migration: Reassign all DB references from an old user UID to a new UID. + +When a user is re-created in PORTA (same username, new UUID), all existing records +still reference the old UUID. This script scans ALL registered databases and tables +for VARCHAR columns containing the old UID and updates them to the new UID. + +Affected columns include: + - sysCreatedBy / sysModifiedBy (on every table via PowerOnModel) + - userId, revokedBy, createdByUserId, publishedBy, triggeredBy, assignedTo, etc. + +The script auto-detects the new UID from the UserInDB table by username. + +Usage: + # Dry-run (default) — shows what would change, no writes: + python scripts/script_migrate_user_uid.py --username patrick.helvetia --old-uid + + # Execute for real: + python scripts/script_migrate_user_uid.py --username patrick.helvetia --old-uid --execute +""" + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import List, Optional, Tuple + +scriptPath = Path(__file__).resolve() +gatewayPath = scriptPath.parent.parent +sys.path.insert(0, str(gatewayPath)) +os.chdir(str(gatewayPath)) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True) +logger = logging.getLogger(__name__) + +import psycopg2 +import psycopg2.extras +from modules.shared.configuration import APP_CONFIG + + +ALL_DATABASES = [ + "poweron_app", + "poweron_chat", + "poweron_management", + "poweron_knowledge", + "poweron_billing", + "poweron_workspace", + "poweron_graphicaleditor", + "poweron_chatbot", + "poweron_trustee", + "poweron_commcoach", + "poweron_neutralization", + "poweron_realestate", + "poweron_teamsbot", +] + + +def _getConnection(dbName: str): + return psycopg2.connect( + host=APP_CONFIG.get("DB_HOST", "localhost"), + port=int(APP_CONFIG.get("DB_PORT", "5432")), + database=dbName, + user=APP_CONFIG.get("DB_USER"), + password=APP_CONFIG.get("DB_PASSWORD_SECRET"), + client_encoding="utf8", + ) + + +def _getTablesInDb(conn) -> List[str]: + with conn.cursor() as cur: + cur.execute(""" + SELECT table_name FROM information_schema.tables + WHERE table_schema = 'public' + AND table_type = 'BASE TABLE' + AND table_name NOT LIKE '\\_%%' + ORDER BY table_name + """) + return [row[0] for row in cur.fetchall()] + + +def _getVarcharColumns(conn, tableName: str) -> List[str]: + """Get all VARCHAR/TEXT columns for a table (potential user-ID carriers).""" + with conn.cursor() as cur: + cur.execute(""" + SELECT column_name FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = %s + AND data_type IN ('character varying', 'text') + ORDER BY ordinal_position + """, (tableName,)) + return [row[0] for row in cur.fetchall()] + + +def _countMatches(conn, tableName: str, columnName: str, oldUid: str) -> int: + with conn.cursor() as cur: + cur.execute( + f'SELECT COUNT(*) FROM "{tableName}" WHERE "{columnName}" = %s', + (oldUid,), + ) + return cur.fetchone()[0] + + +def _updateColumn(conn, tableName: str, columnName: str, oldUid: str, newUid: str) -> int: + with conn.cursor() as cur: + cur.execute( + f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE "{columnName}" = %s', + (newUid, oldUid), + ) + return cur.rowcount + + +def _lookupNewUid(username: str) -> Optional[str]: + """Find the current UID for a username in poweron_app.UserInDB.""" + conn = _getConnection("poweron_app") + try: + with conn.cursor() as cur: + cur.execute( + 'SELECT "id" FROM "UserInDB" WHERE "username" = %s', + (username,), + ) + row = cur.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def _scanJsonbForUid(conn, tableName: str, columnName: str, oldUid: str) -> int: + """Count JSONB fields that contain the old UID as a text value anywhere.""" + with conn.cursor() as cur: + cur.execute( + f"""SELECT COUNT(*) FROM "{tableName}" + WHERE "{columnName}"::text LIKE %s""", + (f"%{oldUid}%",), + ) + return cur.fetchone()[0] + + +def _updateJsonbColumn(conn, tableName: str, columnName: str, oldUid: str, newUid: str) -> int: + """Replace old UID inside JSONB columns using text replacement.""" + with conn.cursor() as cur: + cur.execute( + f"""UPDATE "{tableName}" + SET "{columnName}" = REPLACE("{columnName}"::text, %s, %s)::jsonb + WHERE "{columnName}"::text LIKE %s""", + (oldUid, newUid, f"%{oldUid}%"), + ) + return cur.rowcount + + +def _getJsonbColumns(conn, tableName: str) -> List[str]: + """Get all JSONB columns for a table.""" + with conn.cursor() as cur: + cur.execute(""" + SELECT column_name FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = %s + AND data_type = 'jsonb' + ORDER BY ordinal_position + """, (tableName,)) + return [row[0] for row in cur.fetchall()] + + +def migrate(username: str, oldUid: str, execute: bool = False): + newUid = _lookupNewUid(username) + if not newUid: + logger.error(f"User '{username}' not found in UserInDB. Cannot determine new UID.") + sys.exit(1) + + if newUid == oldUid: + logger.error(f"Old UID and new UID are identical ({oldUid}). Nothing to migrate.") + sys.exit(1) + + logger.info(f"Migration: user '{username}'") + logger.info(f" Old UID: {oldUid}") + logger.info(f" New UID: {newUid}") + logger.info(f" Mode: {'EXECUTE' if execute else 'DRY-RUN'}") + logger.info("") + + totalUpdated = 0 + findings: List[Tuple[str, str, str, int]] = [] + + for dbName in ALL_DATABASES: + try: + conn = _getConnection(dbName) + except Exception as e: + logger.warning(f" Cannot connect to {dbName}: {e}") + continue + + try: + conn.autocommit = False + tables = _getTablesInDb(conn) + + for tableName in tables: + varcharCols = _getVarcharColumns(conn, tableName) + for col in varcharCols: + count = _countMatches(conn, tableName, col, oldUid) + if count > 0: + findings.append((dbName, tableName, col, count)) + if execute: + updated = _updateColumn(conn, tableName, col, oldUid, newUid) + totalUpdated += updated + logger.info(f" [UPDATED] {dbName}.{tableName}.{col}: {updated} rows") + else: + logger.info(f" [DRY-RUN] {dbName}.{tableName}.{col}: {count} rows would be updated") + + jsonbCols = _getJsonbColumns(conn, tableName) + for col in jsonbCols: + count = _scanJsonbForUid(conn, tableName, col, oldUid) + if count > 0: + findings.append((dbName, tableName, f"{col} (JSONB)", count)) + if execute: + _updateJsonbColumn(conn, tableName, col, oldUid, newUid) + totalUpdated += count + logger.info(f" [UPDATED] {dbName}.{tableName}.{col} (JSONB): {count} rows") + else: + logger.info(f" [DRY-RUN] {dbName}.{tableName}.{col} (JSONB): {count} rows would be updated") + + if execute: + conn.commit() + else: + conn.rollback() + except Exception as e: + conn.rollback() + logger.error(f" Error processing {dbName}: {e}") + finally: + conn.close() + + logger.info("") + logger.info("=" * 70) + logger.info("SUMMARY") + logger.info("=" * 70) + if not findings: + logger.info(" No references to old UID found in any database.") + else: + logger.info(f" Found {len(findings)} column(s) with references to old UID:") + for dbName, tableName, col, count in findings: + logger.info(f" {dbName}.{tableName}.{col}: {count} rows") + logger.info("") + if execute: + logger.info(f" Total rows updated: {totalUpdated}") + else: + logger.info(f" Total rows that would be updated: {sum(c for _, _, _, c in findings)}") + logger.info("") + logger.info(" To apply changes, re-run with --execute") + + +def main(): + parser = argparse.ArgumentParser( + description="Migrate all DB references from old user UID to new UID." + ) + parser.add_argument( + "--username", + required=True, + help="Username to migrate (e.g. 'patrick.helvetia'). Used to look up the new UID.", + ) + parser.add_argument( + "--old-uid", + required=True, + help="The old UUID that is orphaned in the database.", + ) + parser.add_argument( + "--execute", + action="store_true", + default=False, + help="Actually perform the migration. Without this flag, only a dry-run is done.", + ) + args = parser.parse_args() + + migrate(username=args.username, oldUid=args.old_uid, execute=args.execute) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py index 57094760..5fb505d7 100644 --- a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py +++ b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py @@ -30,6 +30,7 @@ import psycopg2.errors from modules.connectors.connectorDbPostgre import ( DatabaseConnector, DatabaseQueryError, + _stripNulBytesFromStr, ) @@ -164,3 +165,12 @@ class TestGetRecordFailLoud: assert excinfo.value.table == "DummyTable" conn.rollback.assert_called_once() + + +class TestStripNulBytesFromStr: + def test_removesNul(self): + assert _stripNulBytesFromStr("a\x00b") == "ab" + + def test_passthroughNonStr(self): + assert _stripNulBytesFromStr(None) is None + assert _stripNulBytesFromStr(7) == 7 diff --git a/tests/unit/services/test_buildTree.py b/tests/unit/services/test_buildTree.py new file mode 100644 index 00000000..5a2bacb4 --- /dev/null +++ b/tests/unit/services/test_buildTree.py @@ -0,0 +1,359 @@ +"""Unit tests for the generic UDB tree builder. + +Verifies key encoding/decoding and that children for parent keys with +existing handlers (top-level, conn, mgrp, feat) are produced with the +correct effective-flag triplet. +""" + +from __future__ import annotations + +import asyncio +import unittest +from unittest.mock import MagicMock, patch + +from modules.serviceCenter.services.serviceKnowledge import _buildTree + + +class TestKeyCoding(unittest.TestCase): + def test_encode_decode_roundtrip(self): + key = _buildTree._encode("ds", "conn-1", "sharepointFolder", "/sites/x") + kind, parts = _buildTree._decode(key) + self.assertEqual(kind, "ds") + self.assertEqual(parts, ["conn-1", "sharepointFolder", "/sites/x"]) + + def test_top_level_kinds(self): + self.assertEqual(_buildTree._decode("conn|abc")[0], "conn") + self.assertEqual(_buildTree._decode("mgrp|m1")[0], "mgrp") + self.assertEqual(_buildTree._decode("feat|m1|trustee|fi-1")[1], ["m1", "trustee", "fi-1"]) + + +class TestEffectiveTriplets(unittest.TestCase): + def test_ds_triplet_no_record_returns_defaults(self): + result = _buildTree._effectiveTripletDs("c", "msft", "/", []) + self.assertEqual(result, { + "effectiveNeutralize": False, + "effectiveScope": "personal", + "effectiveRagIndexEnabled": False, + }) + + def test_ds_triplet_inherits_from_root(self): + root = { + "id": "r", "connectionId": "c", "sourceType": "msft", "path": "/", + "neutralize": True, "scope": "mandate", "ragIndexEnabled": True, + } + result = _buildTree._effectiveTripletDs("c", "sharepointFolder", "/sites/x", [root]) + self.assertEqual(result["effectiveNeutralize"], True) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + def test_fds_triplet_inherits_from_workspace_wildcard(self): + ws = { + "id": "ws", "workspaceInstanceId": "ws-inst", "featureInstanceId": "fi1", + "tableName": "*", "recordFilter": None, "neutralize": True, + "scope": "mandate", "ragIndexEnabled": True, + } + result = _buildTree._effectiveTripletFds("fi1", "Pos", None, [ws]) + self.assertEqual(result["effectiveNeutralize"], True) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + +class TestRecordLookup(unittest.TestCase): + def test_finds_ds_record_by_normalised_path(self): + rec = {"id": "x", "connectionId": "c", "sourceType": "msft", "path": "/folder"} + self.assertEqual(_buildTree._findDsRecord([rec], "c", "msft", "/folder/").get("id"), "x") + self.assertIsNone(_buildTree._findDsRecord([rec], "c", "msft", "/other")) + + def test_finds_fds_record_with_matching_filter(self): + rec = {"id": "f", "workspaceInstanceId": "ws", "featureInstanceId": "fi1", "tableName": "Pos", "recordFilter": {"id": "5"}} + self.assertEqual(_buildTree._findFdsRecord([rec], "fi1", "Pos", {"id": "5"}).get("id"), "f") + self.assertIsNone(_buildTree._findFdsRecord([rec], "fi1", "Pos", {"id": "99"})) + + def test_fds_record_with_none_filter_matches_only_none(self): + rec = {"id": "f", "workspaceInstanceId": "ws", "featureInstanceId": "fi1", "tableName": "*", "recordFilter": None} + self.assertEqual(_buildTree._findFdsRecord([rec], "fi1", "*", None).get("id"), "f") + self.assertIsNone(_buildTree._findFdsRecord([rec], "fi1", "*", {"id": "1"})) + + +class TestGetChildrenForParents(unittest.TestCase): + """End-to-end orchestrator test with mocked dependencies.""" + + def _runAsync(self, coro): + return asyncio.get_event_loop().run_until_complete(coro) + + def test_unknown_parent_key_returns_empty_list(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + mockRoot.return_value = rootIf + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", ["bogus|key"], ctx) + ) + self.assertEqual(result["bogus|key"], []) + + def test_top_level_emits_personal_root_first(self): + """Top-level emits personalRoot first, then mandate-group nodes inline.""" + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + rootIf.getUserMandates.return_value = [] + mockRoot.return_value = rootIf + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + children = result["__root__"] + self.assertGreaterEqual(len(children), 1) + personalRoot = children[0] + self.assertEqual(personalRoot["key"], "personalRoot") + self.assertEqual(personalRoot["kind"], "synthRoot") + self.assertIsNone(personalRoot["parentKey"]) + self.assertTrue(personalRoot["hasChildren"]) + self.assertTrue(personalRoot["defaultExpanded"]) + + +class TestTopLevelLayout(unittest.TestCase): + """Tests for the flat top-level layout (personalRoot + mandate groups).""" + + def _runAsync(self, coro): + return asyncio.get_event_loop().run_until_complete(coro) + + def test_personal_root_carries_neutral_default_triplet(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + rootIf.getUserMandates.return_value = [] + mockRoot.return_value = rootIf + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + personalRoot = result["__root__"][0] + self.assertFalse(personalRoot["effectiveNeutralize"]) + self.assertEqual(personalRoot["effectiveScope"], "personal") + self.assertFalse(personalRoot["effectiveRagIndexEnabled"]) + self.assertFalse(personalRoot["supportsRag"]) + self.assertFalse(personalRoot["canBeAdded"]) + self.assertIsNone(personalRoot["dataSourceId"]) + self.assertIsNone(personalRoot["modelType"]) + + def test_personal_root_emits_active_connection_with_correct_parent(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.serviceCenter.getService") as mockGetService: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + mockRoot.return_value = rootIf + + chatService = MagicMock() + chatService.getUserConnections.return_value = [{ + "id": "conn-1", + "status": "active", + "authority": "msft", + "externalEmail": "user@example.com", + }] + mockGetService.return_value = chatService + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", ["personalRoot"], ctx) + ) + children = result["personalRoot"] + self.assertEqual(len(children), 1) + self.assertEqual(children[0]["key"], "conn|conn-1") + self.assertEqual(children[0]["kind"], "connection") + self.assertEqual(children[0]["parentKey"], "personalRoot") + self.assertEqual(children[0]["label"], "user@example.com") + self.assertTrue(children[0]["supportsRag"]) + + def test_personal_root_skips_inactive_connection(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.serviceCenter.getService") as mockGetService: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + mockRoot.return_value = rootIf + + chatService = MagicMock() + chatService.getUserConnections.return_value = [ + {"id": "c1", "status": "active", "authority": "msft", "externalEmail": "a"}, + {"id": "c2", "status": "expired", "authority": "google", "externalEmail": "b"}, + ] + mockGetService.return_value = chatService + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", ["personalRoot"], ctx) + ) + self.assertEqual(len(result["personalRoot"]), 1) + self.assertEqual(result["personalRoot"][0]["connectionId"], "c1") + + def test_mandate_groups_emitted_inline_at_top_level(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + userMandate = MagicMock() + userMandate.mandateId = "m1" + rootIf.getUserMandates.return_value = [userMandate] + featureInst = MagicMock() + featureInst.id = "fi-1" + featureInst.featureCode = "trustee" + featureInst.enabled = True + rootIf.getFeatureInstancesByMandate.return_value = [featureInst] + featureAccess = MagicMock() + featureAccess.enabled = True + rootIf.getFeatureAccess.return_value = featureAccess + mockRoot.return_value = rootIf + + catalog = MagicMock() + catalog.getFeaturesWithDataObjects.return_value = ["trustee"] + mockCatalog.return_value = catalog + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = None + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + children = result["__root__"] + byKey = {c["key"]: c for c in children} + self.assertIn("personalRoot", byKey) + self.assertIn("mgrp|m1", byKey) + mgroup = byKey["mgrp|m1"] + self.assertEqual(mgroup["kind"], "mandateGroup") + self.assertIsNone(mgroup["parentKey"]) + self.assertEqual(mgroup["mandateId"], "m1") + self.assertTrue(mgroup["defaultExpanded"]) + self.assertFalse(mgroup["supportsRag"]) + + def test_top_level_omits_mandates_without_data_features(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + userMandate = MagicMock() + userMandate.mandateId = "m1" + rootIf.getUserMandates.return_value = [userMandate] + rootIf.getFeatureInstancesByMandate.return_value = [] + mockRoot.return_value = rootIf + + catalog = MagicMock() + catalog.getFeaturesWithDataObjects.return_value = ["trustee"] + mockCatalog.return_value = catalog + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = None + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + keys = [c["key"] for c in result["__root__"]] + self.assertEqual(keys, ["personalRoot"]) + + def test_personal_root_listed_first_via_display_order(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + userMandate = MagicMock() + userMandate.mandateId = "m1" + rootIf.getUserMandates.return_value = [userMandate] + featureInst = MagicMock() + featureInst.id = "fi-1" + featureInst.featureCode = "trustee" + featureInst.enabled = True + rootIf.getFeatureInstancesByMandate.return_value = [featureInst] + featureAccess = MagicMock() + featureAccess.enabled = True + rootIf.getFeatureAccess.return_value = featureAccess + mockRoot.return_value = rootIf + + catalog = MagicMock() + catalog.getFeaturesWithDataObjects.return_value = ["trustee"] + mockCatalog.return_value = catalog + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = None + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + children = result["__root__"] + self.assertEqual(children[0]["key"], "personalRoot") + self.assertEqual(children[0]["displayOrder"], 0) + + +class TestFeatureTableFields(unittest.TestCase): + """Per-column field expansion under a feature data-source table.""" + + def test_emits_one_node_per_field(self): + nodes = _buildTree._featureTableFields( + parentKey="fdstbl|fi-1|TrusteePosition", + featureInstanceId="fi-1", + tableName="TrusteePosition", + fieldNames=["id", "valuta", "company"], + allFds=[], + ) + self.assertEqual(len(nodes), 3) + self.assertEqual(nodes[0]["kind"], "fdsField") + self.assertEqual(nodes[0]["fieldName"], "id") + self.assertEqual(nodes[0]["parentKey"], "fdstbl|fi-1|TrusteePosition") + self.assertEqual(nodes[0]["key"], "fdsfld|fi-1|TrusteePosition|id") + self.assertFalse(nodes[0]["hasChildren"]) + self.assertFalse(nodes[0]["supportsRag"]) + + def test_field_neutralize_inherits_from_table_blanket(self): + rec = {"id": "f", "workspaceInstanceId": "ws-1", "featureInstanceId": "fi-1", + "tableName": "TrusteePosition", "recordFilter": None, + "neutralize": True, "neutralizeFields": None, + "scope": None, "ragIndexEnabled": False} + nodes = _buildTree._featureTableFields( + parentKey="fdstbl|fi-1|TrusteePosition", + featureInstanceId="fi-1", + tableName="TrusteePosition", + fieldNames=["email", "company"], + allFds=[rec], + ) + self.assertTrue(nodes[0]["effectiveNeutralize"]) + self.assertTrue(nodes[1]["effectiveNeutralize"]) + + def test_field_neutralize_explicit_via_neutralize_fields(self): + rec = {"id": "f", "workspaceInstanceId": "ws-1", "featureInstanceId": "fi-1", + "tableName": "TrusteePosition", "recordFilter": None, + "neutralize": False, "neutralizeFields": ["email"], + "scope": None, "ragIndexEnabled": False} + nodes = _buildTree._featureTableFields( + parentKey="fdstbl|fi-1|TrusteePosition", + featureInstanceId="fi-1", + tableName="TrusteePosition", + fieldNames=["email", "company"], + allFds=[rec], + ) + byField = {n["fieldName"]: n for n in nodes} + self.assertTrue(byField["email"]["effectiveNeutralize"]) + self.assertFalse(byField["company"]["effectiveNeutralize"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/services/test_inheritFlags.py b/tests/unit/services/test_inheritFlags.py index b177e767..98e6fb41 100644 --- a/tests/unit/services/test_inheritFlags.py +++ b/tests/unit/services/test_inheritFlags.py @@ -1,12 +1,12 @@ """Unit tests for `_inheritFlags` cascade-inherit helpers. Verifies: -- getEffectiveFlag walks ancestors via path-prefix matching -- root default is False (or 'personal' for scope) when nothing explicit in chain -- only same-connectionId AND same-sourceType ancestors are considered -- cascadeResetDescendants only touches descendants with explicit values for THAT flag -- '/' is treated as ancestor of every non-root path -- '/foo' is NOT ancestor of '/foobar' (must require '/' separator) +- getEffectiveFlag mode='walk': walks ancestors via path-prefix matching +- getEffectiveFlag mode='aggregate': returns 'mixed' when subtree diverges +- cascadeResetDescendants: bottom-up reset returning List[str] +- cascadeResetDescendantsFds: same for FeatureDataSource +- collectAncestorChain / collectAncestorChainFds: ancestor discovery +- buildEffectiveByConnection / buildEffectiveByWorkspaceFds: batch compute """ from __future__ import annotations @@ -33,7 +33,26 @@ def _ds(idVal: str, path: str, **flags) -> dict: return base -class TestEffectiveFlag(unittest.TestCase): +def _fds(idVal: str, *, tableName: str, recordFilter=None, featureInstanceId="fi-1", **flags) -> dict: + """Build a FeatureDataSource dict fixture.""" + base = { + "id": idVal, + "workspaceInstanceId": "ws-1", + "featureInstanceId": featureInstanceId, + "tableName": tableName, + "recordFilter": recordFilter, + "neutralize": None, + "scope": None, + } + base.update(flags) + return base + + +# =========================================================================== +# DataSource: getEffectiveFlag mode='walk' +# =========================================================================== + +class TestEffectiveFlagWalk(unittest.TestCase): def test_explicit_own_value_wins(self): root = _ds("r", "/", neutralize=False) leaf = _ds("l", "/folder/sub", neutralize=True) @@ -65,7 +84,6 @@ class TestEffectiveFlag(unittest.TestCase): self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [otherType, leaf])) def test_path_separator_required(self): - """`/foo` must NOT be ancestor of `/foobar` (no shared `/` boundary).""" notAncestor = _ds("a", "/foo", neutralize=True) leaf = _ds("l", "/foobar") self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [notAncestor, leaf])) @@ -90,32 +108,101 @@ class TestEffectiveFlag(unittest.TestCase): _inheritFlags.getEffectiveFlag(leaf, "unknownFlag", [leaf]) def test_explicit_false_overrides_inherited_true(self): - """Explicit False on a child must NOT cascade up to True from an ancestor.""" root = _ds("r", "/", neutralize=True) leaf = _ds("l", "/folder", neutralize=False) self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf])) def test_connection_root_inherits_cross_sourcetype(self): - """Connection-root (sourceType=authority, path='/') is ancestor of all DS in that connection.""" connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) spService = _ds("sp", "/", sourceType="sharepointFolder") olService = _ds("ol", "/", sourceType="outlookFolder") - self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", [connRoot, spService, olService])) - self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", [connRoot, spService, olService])) + allDs = [connRoot, spService, olService] + self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", allDs)) + self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", allDs)) def test_same_sourcetype_ancestor_wins_over_connection_root(self): - """A same-sourceType service-root ancestor beats the connection-root.""" connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) spRoot = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) spLeaf = _ds("spl", "/sites/x", sourceType="sharepointFolder") self.assertFalse(_inheritFlags.getEffectiveFlag(spLeaf, "neutralize", [connRoot, spRoot, spLeaf])) def test_connection_root_does_not_self_inherit(self): - """Connection-root has no ancestor — does not infinite-loop on itself.""" connRoot = _ds("conn", "/", sourceType="msft") self.assertFalse(_inheritFlags.getEffectiveFlag(connRoot, "neutralize", [connRoot])) +# =========================================================================== +# DataSource: getEffectiveFlag mode='aggregate' +# =========================================================================== + +class TestEffectiveFlagAggregate(unittest.TestCase): + def test_leaf_without_descendants_returns_concrete(self): + leaf = _ds("l", "/folder", neutralize=True) + self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [leaf], mode="aggregate")) + + def test_all_descendants_same_returns_concrete(self): + root = _ds("r", "/", neutralize=True) + child1 = _ds("c1", "/a", neutralize=True) + child2 = _ds("c2", "/b") # inherits True from root + allDs = [root, child1, child2] + self.assertTrue(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate")) + + def test_divergent_descendants_returns_mixed(self): + root = _ds("r", "/", neutralize=True) + child1 = _ds("c1", "/a", neutralize=False) + child2 = _ds("c2", "/b") # inherits True from root + allDs = [root, child1, child2] + self.assertEqual(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"), "mixed") + + def test_mixed_scope(self): + root = _ds("r", "/", scope="personal") + child1 = _ds("c1", "/a", scope="team") + child2 = _ds("c2", "/b") # inherits personal from root + allDs = [root, child1, child2] + self.assertEqual(_inheritFlags.getEffectiveFlag(root, "scope", allDs, mode="aggregate"), "mixed") + + def test_all_scope_same_explicit_returns_concrete(self): + root = _ds("r", "/", scope="team") + child1 = _ds("c1", "/a", scope="team") + child2 = _ds("c2", "/b") # inherits team + allDs = [root, child1, child2] + self.assertEqual(_inheritFlags.getEffectiveFlag(root, "scope", allDs, mode="aggregate"), "team") + + def test_connection_root_aggregate_cross_sourcetype(self): + connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) + spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) + olInherit = _ds("ol", "/", sourceType="outlookFolder") # inherits True + allDs = [connRoot, spExplicit, olInherit] + self.assertEqual( + _inheritFlags.getEffectiveFlag(connRoot, "neutralize", allDs, mode="aggregate"), + "mixed", + ) + + def test_mid_level_aggregate_only_considers_own_subtree(self): + root = _ds("r", "/", neutralize=True) + mid = _ds("m", "/folder", neutralize=True) + midChild = _ds("mc", "/folder/sub", neutralize=True) + sibling = _ds("s", "/other", neutralize=False) # not under mid + allDs = [root, mid, midChild, sibling] + # mid's subtree is just midChild(True) + mid(True) = uniform + self.assertTrue(_inheritFlags.getEffectiveFlag(mid, "neutralize", allDs, mode="aggregate")) + # root's subtree includes sibling(False) = mixed + self.assertEqual( + _inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"), + "mixed", + ) + + def test_walk_mode_never_returns_mixed(self): + root = _ds("r", "/", neutralize=True) + child = _ds("c", "/a", neutralize=False) + allDs = [root, child] + self.assertTrue(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="walk")) + + +# =========================================================================== +# DataSource: cascadeResetDescendants (bottom-up, List[str]) +# =========================================================================== + class TestCascadeReset(unittest.TestCase): def _makeRootIf(self, dataSources: List[dict]): rootIf = MagicMock() @@ -127,54 +214,76 @@ class TestCascadeReset(unittest.TestCase): rootIf.db.recordModify = MagicMock(side_effect=_modify) return rootIf, modified + def test_returns_list_of_ids(self): + parent = _ds("p", "/sites", neutralize=True) + child = _ds("c1", "/sites/folder1", neutralize=False) + rootIf, _ = self._makeRootIf([parent, child]) + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + self.assertIsInstance(result, list) + self.assertEqual(result, ["c1"]) + def test_resets_only_explicit_descendants(self): parent = _ds("p", "/sites", neutralize=True) explicitChild = _ds("c1", "/sites/folder1", neutralize=False) - inheritChild = _ds("c2", "/sites/folder2") # inherit -> not touched - sibling = _ds("s", "/other", neutralize=True) # NOT a descendant + inheritChild = _ds("c2", "/sites/folder2") + sibling = _ds("s", "/other", neutralize=True) rootIf, modified = self._makeRootIf([parent, explicitChild, inheritChild, sibling]) - affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") - self.assertEqual(affected, 1) + self.assertEqual(result, ["c1"]) self.assertEqual(modified, [("c1", {"neutralize": None})]) - def test_does_not_touch_other_flags(self): - parent = _ds("p", "/sites", neutralize=True) - child = _ds("c", "/sites/sub", neutralize=False, ragIndexEnabled=True) + def test_bottom_up_order(self): + """Deepest items are reset first.""" + parent = _ds("p", "/", neutralize=True) + level1 = _ds("l1", "/a", neutralize=False) + level2 = _ds("l2", "/a/b", neutralize=False) + level3 = _ds("l3", "/a/b/c", neutralize=False) + rootIf, modified = self._makeRootIf([parent, level1, level2, level3]) + + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + + self.assertEqual(result, ["l3", "l2", "l1"]) + + def test_deep_cascade_through_null_items(self): + """null items are skipped (no DB write) but cascade continues deeper.""" + parent = _ds("p", "/", neutralize=True) + nullChild = _ds("n", "/a") # null — no write, but not a barrier + deepExplicit = _ds("d", "/a/b", neutralize=False) + rootIf, modified = self._makeRootIf([parent, nullChild, deepExplicit]) + + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + + self.assertEqual(result, ["d"]) + self.assertEqual(modified, [("d", {"neutralize": None})]) + + def test_does_not_modify_parent(self): + parent = _ds("p", "/", neutralize=True) + child = _ds("c", "/a", neutralize=False) rootIf, modified = self._makeRootIf([parent, child]) - _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") - - self.assertEqual(modified, [("c", {"neutralize": None})]) - # ragIndexEnabled and scope on the child must remain untouched. - - def test_does_not_cross_sourcetype(self): - """Non-connection-root parents stay within their sourceType for cascade.""" - parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder") - otherTypeDescendant = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder") - rootIf, modified = self._makeRootIf([parent, otherTypeDescendant]) - - affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") - - self.assertEqual(affected, 0) - self.assertEqual(modified, []) + self.assertNotIn("p", [m[0] for m in modified]) def test_connection_root_cascades_cross_sourcetype(self): - """Toggle on connection-root cascades into every explicit DS of that connection.""" connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) olInherit = _ds("ol", "/", sourceType="outlookFolder") - spLeafExplicit = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True) - rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeafExplicit]) + spLeaf = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True) + rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeaf]) - affected = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize") + result = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize") - # spExplicit and spLeafExplicit had explicit values → reset. olInherit untouched. - self.assertEqual(affected, 2) - self.assertEqual({m[0] for m in modified}, {"sp", "sp-leaf"}) - for _, fields in modified: - self.assertEqual(fields, {"neutralize": None}) + self.assertEqual(set(result), {"sp", "sp-leaf"}) + # sp-leaf is deeper, should come first + self.assertEqual(result[0], "sp-leaf") + + def test_does_not_cross_sourcetype_for_non_authority(self): + parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder") + otherType = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder") + rootIf, modified = self._makeRootIf([parent, otherType]) + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + self.assertEqual(result, []) def test_unknown_flag_raises(self): parent = _ds("p", "/", neutralize=True) @@ -183,57 +292,59 @@ class TestCascadeReset(unittest.TestCase): _inheritFlags.cascadeResetDescendants(rootIf, parent, "unknownFlag") -def _fds(idVal: str, *, tableName: str, recordFilter=None, **flags) -> dict: - """Build a FeatureDataSource dict fixture.""" - base = { - "id": idVal, - "workspaceInstanceId": "ws-1", - "tableName": tableName, - "recordFilter": recordFilter, - "neutralize": None, - "scope": None, - } - base.update(flags) - return base +# =========================================================================== +# DataSource: collectAncestorChain +# =========================================================================== + +class TestCollectAncestorChain(unittest.TestCase): + def test_returns_nearest_first(self): + root = _ds("r", "/", neutralize=True) + mid = _ds("m", "/a") + leaf = _ds("l", "/a/b") + chain = _inheritFlags.collectAncestorChain(leaf, [root, mid, leaf]) + self.assertEqual([_inheritFlags._getRecordValue(c, "id") for c in chain], ["m", "r"]) + + def test_connection_root_is_last(self): + connRoot = _ds("conn", "/", sourceType="msft") + spRoot = _ds("sp", "/", sourceType="sharepointFolder") + spLeaf = _ds("spl", "/sub", sourceType="sharepointFolder") + chain = _inheritFlags.collectAncestorChain(spLeaf, [connRoot, spRoot, spLeaf]) + ids = [_inheritFlags._getRecordValue(c, "id") for c in chain] + self.assertEqual(ids, ["sp", "conn"]) + + def test_root_has_no_ancestors(self): + root = _ds("r", "/") + chain = _inheritFlags.collectAncestorChain(root, [root]) + self.assertEqual(chain, []) -class TestFdsClassifyAndAncestry(unittest.TestCase): - def test_classify_workspace_wildcard(self): - self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="*")), "workspace") +# =========================================================================== +# DataSource: buildEffectiveByConnection +# =========================================================================== - def test_classify_table_wildcard(self): - self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="Pos")), "table") +class TestBuildEffectiveByConnection(unittest.TestCase): + def test_walk_mode(self): + root = _ds("r", "/", neutralize=True) + child = _ds("c", "/a", neutralize=False) + leaf = _ds("l", "/a/b") # inherits False from child + result = _inheritFlags.buildEffectiveByConnection([root, child, leaf], "neutralize", mode="walk") + self.assertEqual(result, {"r": True, "c": False, "l": False}) - def test_classify_record_specific(self): - rec = _fds("a", tableName="Pos", recordFilter={"id": "r-1"}) - self.assertEqual(_inheritFlags._fdsClassify(rec), "record") - - def test_workspace_is_ancestor_of_table_and_record(self): - ws = _fds("ws", tableName="*") - tbl = _fds("t", tableName="Pos") - rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) - self.assertTrue(_inheritFlags._fdsIsAncestor(ws, tbl)) - self.assertTrue(_inheritFlags._fdsIsAncestor(ws, rec)) - - def test_table_is_ancestor_of_record_same_table_only(self): - tbl = _fds("t", tableName="Pos") - recSame = _fds("r1", tableName="Pos", recordFilter={"id": "1"}) - recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}) - self.assertTrue(_inheritFlags._fdsIsAncestor(tbl, recSame)) - self.assertFalse(_inheritFlags._fdsIsAncestor(tbl, recOther)) - - def test_record_has_no_descendants(self): - rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) - tbl = _fds("t", tableName="Pos") - self.assertFalse(_inheritFlags._fdsIsAncestor(rec, tbl)) - - def test_no_cross_workspace_ancestry(self): - ws = _fds("ws", tableName="*", workspaceInstanceId="ws-A") - rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, workspaceInstanceId="ws-B") - self.assertFalse(_inheritFlags._fdsIsAncestor(ws, rec)) + def test_aggregate_mode(self): + root = _ds("r", "/", neutralize=True) + child = _ds("c", "/a", neutralize=False) + leaf = _ds("l", "/a/b") # inherits False from child + result = _inheritFlags.buildEffectiveByConnection([root, child, leaf], "neutralize", mode="aggregate") + self.assertEqual(result["r"], "mixed") + self.assertEqual(result["c"], False) + self.assertEqual(result["l"], False) -class TestFdsEffectiveFlag(unittest.TestCase): +# =========================================================================== +# FeatureDataSource: getEffectiveFlagFds +# =========================================================================== + +class TestFdsEffectiveFlagWalk(unittest.TestCase): def test_own_explicit_wins(self): ws = _fds("ws", tableName="*", neutralize=False) rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) @@ -262,9 +373,50 @@ class TestFdsEffectiveFlag(unittest.TestCase): def test_unknown_flag_raises(self): rec = _fds("r", tableName="*") with self.assertRaises(ValueError): - _inheritFlags.getEffectiveFlagFds(rec, "ragIndexEnabled", [rec]) + _inheritFlags.getEffectiveFlagFds(rec, "doesNotExist", [rec]) +class TestFdsEffectiveFlagAggregate(unittest.TestCase): + def test_leaf_without_descendants(self): + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) + self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [rec], mode="aggregate")) + + def test_all_descendants_same(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos") # inherits True + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits True + allFds = [ws, tbl, rec] + self.assertTrue(_inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate")) + + def test_divergent_descendants_returns_mixed(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits False from tbl + allFds = [ws, tbl, rec] + self.assertEqual( + _inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"), + "mixed", + ) + + def test_table_aggregate_own_subtree_only(self): + ws = _fds("ws", tableName="*", neutralize=True) + tblA = _fds("tA", tableName="A", neutralize=True) + recA = _fds("rA", tableName="A", recordFilter={"id": "1"}, neutralize=True) + tblB = _fds("tB", tableName="B", neutralize=False) + allFds = [ws, tblA, recA, tblB] + # tblA subtree: all True + self.assertTrue(_inheritFlags.getEffectiveFlagFds(tblA, "neutralize", allFds, mode="aggregate")) + # ws subtree: mixed (tblB is False) + self.assertEqual( + _inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"), + "mixed", + ) + + +# =========================================================================== +# FeatureDataSource: cascadeResetDescendantsFds (bottom-up, List[str]) +# =========================================================================== + class TestFdsCascadeReset(unittest.TestCase): def _makeRootIf(self, fdses): rootIf = MagicMock() @@ -276,6 +428,14 @@ class TestFdsCascadeReset(unittest.TestCase): rootIf.db.recordModify = MagicMock(side_effect=_modify) return rootIf, modified + def test_returns_list_of_ids(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + rootIf, _ = self._makeRootIf([ws, tbl]) + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize") + self.assertIsInstance(result, list) + self.assertEqual(result, ["t"]) + def test_workspace_cascades_to_all_explicit_descendants(self): ws = _fds("ws", tableName="*", neutralize=True) tblExplicit = _fds("t", tableName="Pos", neutralize=False) @@ -283,10 +443,11 @@ class TestFdsCascadeReset(unittest.TestCase): recExplicit = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) rootIf, modified = self._makeRootIf([ws, tblExplicit, tblInherit, recExplicit]) - affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize") + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize") - self.assertEqual(affected, 2) - self.assertEqual({m[0] for m in modified}, {"t", "r"}) + self.assertEqual(set(result), {"t", "r"}) + # record is deeper (depth 2) than table (depth 1), should come first + self.assertEqual(result[0], "r") def test_table_cascades_only_to_same_table_records(self): tbl = _fds("t", tableName="Pos", neutralize=True) @@ -294,25 +455,189 @@ class TestFdsCascadeReset(unittest.TestCase): recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}, neutralize=False) rootIf, modified = self._makeRootIf([tbl, recSame, recOther]) - affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize") + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize") - self.assertEqual(affected, 1) + self.assertEqual(result, ["r1"]) self.assertEqual(modified, [("r1", {"neutralize": None})]) def test_record_has_no_cascade(self): rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) rootIf, modified = self._makeRootIf([rec]) - affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize") - self.assertEqual(affected, 0) - self.assertEqual(modified, []) + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize") + self.assertEqual(result, []) def test_unknown_flag_raises(self): ws = _fds("ws", tableName="*", neutralize=True) rootIf, _ = self._makeRootIf([ws]) with self.assertRaises(ValueError): - _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled") + _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "doesNotExist") +# =========================================================================== +# FeatureDataSource: collectAncestorChainFds +# =========================================================================== + +class TestCollectAncestorChainFds(unittest.TestCase): + def test_record_has_table_then_workspace(self): + ws = _fds("ws", tableName="*") + tbl = _fds("t", tableName="Pos") + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + chain = _inheritFlags.collectAncestorChainFds(rec, [ws, tbl, rec]) + ids = [c["id"] for c in chain] + self.assertEqual(ids, ["t", "ws"]) + + def test_table_has_only_workspace(self): + ws = _fds("ws", tableName="*") + tbl = _fds("t", tableName="Pos") + chain = _inheritFlags.collectAncestorChainFds(tbl, [ws, tbl]) + self.assertEqual([c["id"] for c in chain], ["ws"]) + + def test_workspace_has_no_ancestors(self): + ws = _fds("ws", tableName="*") + chain = _inheritFlags.collectAncestorChainFds(ws, [ws]) + self.assertEqual(chain, []) + + +# =========================================================================== +# FeatureDataSource: buildEffectiveByWorkspaceFds +# =========================================================================== + +class TestBuildEffectiveByWorkspaceFds(unittest.TestCase): + def test_walk_mode(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits False from tbl + result = _inheritFlags.buildEffectiveByWorkspaceFds([ws, tbl, rec], "neutralize", mode="walk") + self.assertEqual(result, {"ws": True, "t": False, "r": False}) + + def test_aggregate_mode(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + result = _inheritFlags.buildEffectiveByWorkspaceFds([ws, tbl, rec], "neutralize", mode="aggregate") + self.assertEqual(result["ws"], "mixed") + self.assertEqual(result["t"], False) + self.assertEqual(result["r"], False) + + +# =========================================================================== +# resolveEffectiveForPath (with and without own record) +# =========================================================================== + +class TestResolveEffectiveForPath(unittest.TestCase): + def test_with_exact_record(self): + root = _ds("r", "/", neutralize=True, scope="mandate", ragIndexEnabled=False) + leaf = _ds("l", "/folder/sub", neutralize=False) + allDs = [root, leaf] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/folder/sub", allDs) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_without_record_inherits_from_ancestor(self): + root = _ds("r", "/", neutralize=True, scope="mandate", ragIndexEnabled=True) + allDs = [root] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/deep/path/file.txt", allDs) + self.assertEqual(result["effectiveNeutralize"], True) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + def test_without_record_inherits_from_closest_ancestor(self): + root = _ds("r", "/", neutralize=True, ragIndexEnabled=True) + mid = _ds("m", "/folder", neutralize=False, ragIndexEnabled=False) + allDs = [root, mid] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/folder/sub/file.txt", allDs) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_without_record_no_ancestors_returns_defaults(self): + allDs: list = [] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/path", allDs) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveScope"], "personal") + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_connection_root_covers_service_subtree(self): + connRoot = _ds("cr", "/", neutralize=True, sourceType="msft") + allDs = [connRoot] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/sites/intranet", allDs) + self.assertEqual(result["effectiveNeutralize"], True) + + def test_exact_record_with_aggregate_mixed(self): + root = _ds("r", "/", neutralize=True) + leaf = _ds("l", "/sub", neutralize=False) + allDs = [root, leaf] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/", allDs, mode="aggregate") + self.assertEqual(result["effectiveNeutralize"], "mixed") + + +class TestResolveEffectiveForFds(unittest.TestCase): + def test_with_exact_record(self): + ws = _fds("ws", tableName="*", neutralize=True, scope="mandate") + tbl = _fds("t", tableName="Pos", neutralize=False, scope="personal") + allFds = [ws, tbl] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveScope"], "personal") + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_without_record_inherits_from_workspace_wildcard(self): + ws = _fds("ws", tableName="*", neutralize=True, scope="mandate", ragIndexEnabled=True) + allFds = [ws] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "Unknown", None, allFds) + self.assertEqual(result["effectiveNeutralize"], True) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + def test_without_record_no_ancestors_returns_defaults(self): + allFds: list = [] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveScope"], "personal") + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_rag_inherits_when_table_overrides_neutralize_only(self): + """Tables that override only neutralize must still inherit RAG from parent.""" + ws = _fds("ws", tableName="*", ragIndexEnabled=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + allFds = [ws, tbl] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds) + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + def test_rag_aggregate_mixed_when_descendants_diverge(self): + ws = _fds("ws", tableName="*", ragIndexEnabled=True) + tbl = _fds("t", tableName="Pos", ragIndexEnabled=False) + allFds = [ws, tbl] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "*", None, allFds, mode="aggregate") + self.assertEqual(result["effectiveRagIndexEnabled"], "mixed") + + def test_inheritable_fds_flags_includes_rag(self): + self.assertIn("ragIndexEnabled", _inheritFlags._INHERITABLE_FDS_FLAGS) + self.assertIn("neutralize", _inheritFlags._INHERITABLE_FDS_FLAGS) + self.assertIn("scope", _inheritFlags._INHERITABLE_FDS_FLAGS) + + +# =========================================================================== +# FDS cascade resets RAG (in addition to neutralize and scope) +# =========================================================================== + +class TestCascadeResetFdsRag(unittest.TestCase): + def test_cascade_resets_rag_on_descendants(self): + ws = _fds("ws", tableName="*") + tbl = _fds("t", tableName="Pos", ragIndexEnabled=False) + allFds = [ws, tbl] + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = allFds + rootIf.db.recordModify = MagicMock() + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled") + self.assertIn("t", result) + rootIf.db.recordModify.assert_called() + + +# =========================================================================== +# Path normalization +# =========================================================================== + class TestPathNormalization(unittest.TestCase): def test_empty_path_normalises_to_root(self): self.assertEqual(_inheritFlags._normalisePath(""), "/") diff --git a/tests/unit/teamsbot/test_directorPrompts.py b/tests/unit/teamsbot/test_directorPrompts.py index f136438a..b8bdaafc 100644 --- a/tests/unit/teamsbot/test_directorPrompts.py +++ b/tests/unit/teamsbot/test_directorPrompts.py @@ -42,7 +42,7 @@ from modules.features.teamsbot.datamodelTeamsbot import ( from modules.features.teamsbot.service import ( TeamsbotService, _activeServices, - _sessionEvents, + sessionEvents, getActiveService, ) @@ -152,10 +152,10 @@ def _buildService() -> TeamsbotService: def _resetGlobals(): """Avoid cross-test bleed in module-level globals.""" _activeServices.clear() - _sessionEvents.clear() + sessionEvents.clear() yield _activeServices.clear() - _sessionEvents.clear() + sessionEvents.clear() # ============================================================================ @@ -251,7 +251,7 @@ class TestBuildPersistentDirectorContext: ] rendered = svc._buildPersistentDirectorContext() assert "OPERATOR_DIRECTIVES" in rendered - assert "- Antworte immer in Englisch." in rendered + assert "Antworte immer in Englisch." in rendered assert "private" in rendered def test_skipsBlankText(self): @@ -261,7 +261,7 @@ class TestBuildPersistentDirectorContext: {"id": "p2", "text": "Sei hoeflich."}, ] rendered = svc._buildPersistentDirectorContext() - assert "- Sei hoeflich." in rendered + assert "Sei hoeflich." in rendered assert "p1" not in rendered # the blank one is filtered out def test_allBlankPromptsResultInEmpty(self): From 9773c00bca3b48216c86adb6e733ae1060f37851 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Tue, 19 May 2026 17:38:18 +0200 Subject: [PATCH 4/6] trustee budget fix --- modules/features/trustee/mainTrustee.py | 10 ++- .../actions/refreshAccountingData.py | 83 +++++++++++++++---- 2 files changed, 74 insertions(+), 19 deletions(-) diff --git a/modules/features/trustee/mainTrustee.py b/modules/features/trustee/mainTrustee.py index b3f7cdcf..41903211 100644 --- a/modules/features/trustee/mainTrustee.py +++ b/modules/features/trustee/mainTrustee.py @@ -484,8 +484,14 @@ TEMPLATE_WORKFLOWS = [ "3. Kurzer Management-Summary-Absatz (3-5 Saetze) UNTER dem Chart " "mit den 3 groessten Abweichungen (>10%) und einer fachlichen " "Einschaetzung.\n\n" - "Verwende die uebergebene Budget-Datei als Soll-Quelle und die im " - "Kontext bereitgestellten Buchhaltungsdaten als Ist-Quelle.\n" + "DATENQUELLEN:\n" + "- SOLL (Budget): Aus der uebergebenen Budget-Datei (Excel).\n" + "- IST (Buchhaltung): Verwende AUSSCHLIESSLICH das Feld " + "\"closingBalance\" aus \"accountSummary\" im Kontext-JSON. " + "Dort steht pro Konto GENAU EIN Ist-Wert (Jahresabschluss-Saldo). " + "Fuer Quartals-Budgets stehen zusaetzlich Q1/Q2/Q3/Q4-Felder bereit. " + "SUMMIERE NIEMALS mehrere Zeilen oder Journal-Eintraege auf -- der " + "closingBalance in accountSummary ist bereits der korrekte Ist-Wert.\n\n" "WICHTIG: Erstelle KEINEN separaten Chart pro Konto. Nur EIN " "Uebersichts-Chart ueber alle Konten ist gewuenscht.\n\n" "Hinweis: Das documentTheme ist 'finance'. Wenn du ein Dokument erstellst, " diff --git a/modules/workflows/methods/methodTrustee/actions/refreshAccountingData.py b/modules/workflows/methods/methodTrustee/actions/refreshAccountingData.py index 6ff5641c..0d6e737c 100644 --- a/modules/workflows/methods/methodTrustee/actions/refreshAccountingData.py +++ b/modules/workflows/methods/methodTrustee/actions/refreshAccountingData.py @@ -38,6 +38,52 @@ def _tsToIso(ts) -> Optional[str]: _SYNC_THRESHOLD_SECONDS = 3600 +def _buildAccountSummary(accountMap: Dict[str, dict], balances: list, year: int) -> list: + """Aggregate balance records into one row per account for *year*. + + For each account the annual balance record (``periodMonth == 0``) of + *year* is preferred. If that row is missing, we also check the + previous year's annual record so that YTD carry-forwards are visible. + Additionally, quarterly closing balances (Q1-Q4) are derived from the + monthly records so the AI can compare against quarterly budgets. + """ + bestClosing: Dict[str, float] = {} + quarterClosing: Dict[str, Dict[str, float]] = {} + + for b in balances: + acct = b.get("accountNumber", "") + bYear = b.get("periodYear", 0) + bMonth = b.get("periodMonth", 0) + closing = b.get("closingBalance", 0) or 0 + + if bYear == year and bMonth == 0: + bestClosing[acct] = closing + + if bYear == year and bMonth in (3, 6, 9, 12): + qLabel = f"Q{bMonth // 3}" + quarterClosing.setdefault(acct, {})[qLabel] = closing + + if acct not in bestClosing and bYear == year - 1 and bMonth == 0: + bestClosing[acct] = closing + + summary = [] + for nr in sorted(accountMap.keys()): + info = accountMap[nr] + row = { + "account": nr, + "label": info.get("label", ""), + "type": info.get("type", ""), + "group": info.get("group", ""), + "closingBalance": round(bestClosing.get(nr, 0), 2), + } + qData = quarterClosing.get(nr, {}) + for q in ("Q1", "Q2", "Q3", "Q4"): + if q in qData: + row[q] = round(qData[q], 2) + summary.append(row) + return summary + + async def refreshAccountingData(self, parameters: Dict[str, Any]) -> ActionResult: """Import/refresh accounting data from the configured external system. @@ -133,7 +179,13 @@ async def refreshAccountingData(self, parameters: Dict[str, Any]) -> ActionResul def _exportAccountingData(trusteeInterface, featureInstanceId: str, dateFrom: str = None, dateTo: str = None) -> str: - """Export accounting data (accounts, balances, journal entries+lines) as compact JSON for downstream AI nodes.""" + """Export accounting data as compact JSON for downstream AI nodes. + + Produces a pre-aggregated ``accountSummary`` (one row per account with + a single *Ist* value) so the AI does not have to navigate thousands of + raw balance records. Raw per-month balances are deliberately omitted to + avoid confusion and reduce payload size. + """ from modules.features.trustee.datamodelFeatureTrustee import ( TrusteeDataAccount, TrusteeDataJournalEntry, @@ -155,17 +207,9 @@ def _exportAccountingData(trusteeInterface, featureInstanceId: str, dateFrom: st } balances = trusteeInterface.db.getRecordset(TrusteeDataAccountBalance, recordFilter=baseFilter) or [] - balanceList = [] - for b in balances: - balanceList.append({ - "account": b.get("accountNumber", ""), - "year": b.get("periodYear", 0), - "month": b.get("periodMonth", 0), - "opening": b.get("openingBalance", 0), - "debit": b.get("debitTotal", 0), - "credit": b.get("creditTotal", 0), - "closing": b.get("closingBalance", 0), - }) + + currentYear = _dt.now(tz=_tz.utc).year + accountSummary = _buildAccountSummary(accountMap, balances, currentYear) entries = trusteeInterface.db.getRecordset(TrusteeDataJournalEntry, recordFilter=baseFilter) or [] fromTs = _isoToTs(dateFrom) @@ -205,21 +249,26 @@ def _exportAccountingData(trusteeInterface, featureInstanceId: str, dateFrom: st }) export = { - "accounts": list(accountMap.values()), - "balances": balanceList, + "accountSummary": accountSummary, "journalLines": lineList, "meta": { "accountCount": len(accountMap), "entryCount": len(entryMap), "lineCount": len(lineList), - "balanceCount": len(balanceList), + "summaryYear": currentYear, "dateFrom": dateFrom, "dateTo": dateTo, + "hint": ( + "accountSummary contains ONE row per account with the " + "current-year closing balance (Ist). Use this for " + "budget comparisons. journalLines lists individual " + "bookings for drill-down." + ), }, } result = json.dumps(export, ensure_ascii=False, default=str) - logger.info("Exported accounting data: %d accounts, %d entries, %d lines, %d balances (%d bytes)", - len(accountMap), len(entryMap), len(lineList), len(balanceList), len(result)) + logger.info("Exported accounting data: %d accounts (summary), %d entries, %d lines (%d bytes)", + len(accountSummary), len(entryMap), len(lineList), len(result)) return result except Exception as e: logger.warning("Could not export accounting data: %s", e) From a173fab15ff3d9fde82a44874018c4c919f47e45 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Tue, 19 May 2026 17:42:24 +0200 Subject: [PATCH 5/6] fix mandate res --- modules/routes/routeHelpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/routes/routeHelpers.py b/modules/routes/routeHelpers.py index bb1386af..b58ffc6d 100644 --- a/modules/routes/routeHelpers.py +++ b/modules/routes/routeHelpers.py @@ -41,7 +41,7 @@ def resolveMandateLabels(ids: List[str]) -> Dict[str, Optional[str]]: m = mMap.get(mid) label = (getattr(m, "label", None) or getattr(m, "name", None)) if m else None if not label: - logger.warning("resolveMandateLabels: no label for id=%s (found=%s)", mid, m is not None) + logger.debug("resolveMandateLabels: no label for id=%s (found=%s)", mid, m is not None) result[mid] = label or None return result @@ -57,7 +57,7 @@ def resolveInstanceLabels(ids: List[str]) -> Dict[str, Optional[str]]: fi = featureIface.getFeatureInstance(iid) label = fi.label if fi and fi.label else None if not label: - logger.warning("resolveInstanceLabels: no label for id=%s (found=%s)", iid, fi is not None) + logger.debug("resolveInstanceLabels: no label for id=%s (found=%s)", iid, fi is not None) result[iid] = label return result @@ -104,7 +104,7 @@ def resolveRoleLabels(ids: List[str]) -> Dict[str, Optional[str]]: out[rid] = r.get("roleLabel") or None for rid in ids: if out.get(rid) is None: - logger.warning("resolveRoleLabels: no label for id=%s", rid) + logger.debug("resolveRoleLabels: no label for id=%s", rid) return out From 09c6d33deca58e715c8e6256a2c9ed0021eee44d Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Tue, 19 May 2026 22:14:00 +0200 Subject: [PATCH 6/6] fixed expenses workflow --- .../mainServiceSharepoint.py | 27 +++++++------------ .../methodTrustee/actions/processDocuments.py | 21 ++++++++++++--- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py b/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py index 483d7fbe..4fd1fb36 100644 --- a/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py +++ b/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py @@ -327,27 +327,20 @@ class SharepointService: return None async def uploadFile(self, siteId: str, folderPath: str, fileName: str, content: bytes) -> Dict[str, Any]: - """Upload a file to SharePoint.""" - try: - # Clean the path - cleanPath = folderPath.lstrip('/') - uploadPath = f"{cleanPath.rstrip('/')}/{fileName}" - endpoint = f"sites/{siteId}/drive/root:/{uploadPath}:/content" + """Upload a file to SharePoint. Raises on failure.""" + cleanPath = folderPath.lstrip('/') + uploadPath = f"{cleanPath.rstrip('/')}/{fileName}" + endpoint = f"sites/{siteId}/drive/root:/{uploadPath}:/content" - logger.info(f"Uploading file to: {endpoint}") + logger.info(f"Uploading file to: {endpoint}") - result = await self._makeGraphApiCall(endpoint, method="PUT", data=content) + result = await self._makeGraphApiCall(endpoint, method="PUT", data=content) - if "error" in result: - logger.error(f"Upload failed: {result['error']}") - return result + if "error" in result: + raise Exception(f"Upload failed: {result['error']}") - logger.info(f"File uploaded successfully: {fileName}") - return result - - except Exception as e: - logger.error(f"Error uploading file: {str(e)}") - return {"error": f"Error uploading file: {str(e)}"} + logger.info(f"File uploaded successfully: {fileName}") + return result async def downloadFile(self, siteId: str, fileId: str) -> Optional[bytes]: """Download a file from SharePoint.""" diff --git a/modules/workflows/methods/methodTrustee/actions/processDocuments.py b/modules/workflows/methods/methodTrustee/actions/processDocuments.py index b05e25f4..29d5ab13 100644 --- a/modules/workflows/methods/methodTrustee/actions/processDocuments.py +++ b/modules/workflows/methods/methodTrustee/actions/processDocuments.py @@ -247,16 +247,29 @@ def _resolveDocumentList(documentListParam, services) -> List[tuple]: if isinstance(first, dict) and ("documentData" in first or "documentName" in first): for doc in documentListParam: rawData = doc.get("documentData") - logger.debug("_resolveDocumentList: doc keys=%s documentData type=%s documentData truthy=%s", list(doc.keys()), type(rawData).__name__, bool(rawData)) + fileId = (doc.get("validationMetadata") or {}).get("fileId") or doc.get("fileId", "") + fileName = doc.get("documentName") or doc.get("fileName") or "document" + mimeType = doc.get("mimeType") or doc.get("documentMimeType") or "application/json" + + # When documentData was persisted as binary (_hasBinaryData), read it + # back from file storage via the chat service. + if not rawData and doc.get("_hasBinaryData") and fileId: + chatService = getattr(services, "chat", None) + if chatService: + try: + rawBytes = chatService.getFileData(fileId) + if rawBytes: + rawData = rawBytes.decode("utf-8") if isinstance(rawBytes, bytes) else rawBytes + except Exception as e: + logger.debug("_resolveDocumentList: failed to read binary for fileId=%s: %s", fileId, e) + + logger.debug("_resolveDocumentList: doc keys=%s documentData type=%s documentData truthy=%s", list(doc.keys()), type(rawData).__name__ if rawData else "NoneType", bool(rawData)) if not rawData: continue try: data = json.loads(rawData) if isinstance(rawData, str) else rawData except (json.JSONDecodeError, TypeError): continue - fileId = (doc.get("validationMetadata") or {}).get("fileId") or doc.get("fileId", "") - fileName = doc.get("documentName") or doc.get("fileName") or "document" - mimeType = doc.get("mimeType") or doc.get("documentMimeType") or "application/json" results.append((data, fileId, fileName, mimeType)) if results: return results