From 2bb65c230369b3b54e8f8c06e32d8296d1c8e100 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Sun, 17 May 2026 20:38:37 +0200
Subject: [PATCH 1/6] db connection pooling and rag limit transparency
---
app.py | 11 +-
modules/connectors/connectorDbPostgre.py | 1180 ++++++++++-------
.../realEstate/interfaceFeatureRealEstate.py | 8 +-
modules/interfaces/interfaceDbBilling.py | 18 +-
modules/interfaces/interfaceDbKnowledge.py | 16 +
modules/interfaces/interfaceDbManagement.py | 32 +-
modules/interfaces/interfaceRbac.py | 12 +-
modules/routes/routeHelpers.py | 2 +-
modules/routes/routeRagInventory.py | 65 +-
modules/routes/routeWorkflowDashboard.py | 4 +-
.../coreTools/_featureSubAgentTools.py | 15 +-
.../subConnectorSyncClickup.py | 43 +-
.../subConnectorSyncGdrive.py | 43 +-
.../subConnectorSyncKdrive.py | 46 +-
.../subConnectorSyncSharepoint.py | 56 +-
modules/shared/configuration.py | 63 +-
modules/shared/dbMultiTenantOptimizations.py | 101 +-
modules/shared/gdprDeletion.py | 178 ++-
scripts/stage0_filefolder_schema_check.py | 2 +-
.../test_connectorDbPostgre_failLoud.py | 82 +-
.../test_connectorDbPostgre_pool.py | 304 +++++
tests/unit/interfaces/test_folderRbac.py | 10 +
tests/unit/routes/test_folder_crud.py | 10 +
23 files changed, 1519 insertions(+), 782 deletions(-)
create mode 100644 tests/unit/connectors/test_connectorDbPostgre_pool.py
diff --git a/app.py b/app.py
index 7a4ed4d4..d94c7dd5 100644
--- a/app.py
+++ b/app.py
@@ -438,7 +438,16 @@ async def lifespan(app: FastAPI):
logger.error(f"Feature '{featureName}' failed to stop: {e}")
except Exception as e:
logger.warning(f"Could not shutdown feature containers: {e}")
-
+
+ # --- Close all PostgreSQL connection pools ---
+ # Must run LAST: feature `onStop` hooks may still issue DB calls during
+ # shutdown. Once we tear down the pools, no more borrows are possible.
+ try:
+ from modules.connectors.connectorDbPostgre import closeAllPools
+ closeAllPools()
+ except Exception as e:
+ logger.warning(f"Closing DB connection pools failed: {e}")
+
logger.info("Application has been shut down")
diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py
index a6893396..f1a34f70 100644
--- a/modules/connectors/connectorDbPostgre.py
+++ b/modules/connectors/connectorDbPostgre.py
@@ -2,9 +2,12 @@
# All rights reserved.
import contextvars
import re
+import time
import psycopg2
import psycopg2.extras
+import psycopg2.pool
import logging
+from contextlib import contextmanager
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
import uuid
from pydantic import BaseModel, Field
@@ -44,24 +47,6 @@ class DatabaseQueryError(RuntimeError):
self.original = original
-def _rollbackQuietly(connection) -> None:
- """Restore the connection state after a failed query.
-
- Postgres puts the connection in an error state after any failed
- statement; subsequent queries on the same connection raise
- ``InFailedSqlTransaction`` until we rollback. We swallow rollback
- errors because the original query error is what the caller should
- see — a secondary rollback failure typically means the connection
- is gone and will be reopened on the next ``_ensure_connection``.
- """
- if connection is None:
- return
- try:
- connection.rollback()
- except Exception:
- pass
-
-
class SystemTable(PowerOnModel):
"""Data model for system table entries"""
@@ -203,9 +188,174 @@ def _quotePgIdent(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
-# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
-# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
-# contextvars to avoid races when concurrent requests share the same connector.
+# ---------------------------------------------------------------------------
+# Connection pool registry
+# ---------------------------------------------------------------------------
+# psycopg2 connections are NOT thread-safe; sharing one connection across the
+# FastAPI threadpool (sync `def` routes) or across multiple async tasks that
+# happen to be active simultaneously results in either
+# `OperationalError: another command is already in progress` or — far worse —
+# an unbounded hang in `recv()` on the connection socket, because two cursors
+# wait for one server response.
+#
+# `_PoolRegistry` keeps one `ThreadedConnectionPool` per database identity
+# (`host`, `db`, `port`). Every DB call borrows a connection from the pool,
+# runs its query, and returns the connection. The pool itself guarantees that
+# no two callers share the same psycopg2 connection at the same time.
+#
+# `statement_timeout=30000` (30 s) is a safety net: a runaway query is aborted
+# instead of hanging forever and poisoning the connection. `connect_timeout=10`
+# prevents indefinite blocking when the Postgres host is unreachable.
+
+_DEFAULT_POOL_MIN = 2
+_DEFAULT_POOL_MAX = 20
+_STATEMENT_TIMEOUT_MS = 30000
+_CONNECT_TIMEOUT_S = 10
+# psycopg2.pool.ThreadedConnectionPool.getconn() does NOT block when the pool
+# is exhausted — it raises `psycopg2.pool.PoolError` immediately. That makes
+# bursty workloads (50 concurrent route calls against a max=20 pool) fail
+# spuriously. `borrowConn()` therefore retries with a small backoff up to
+# `_BORROW_WAIT_TIMEOUT_S` seconds before giving up.
+_BORROW_WAIT_TIMEOUT_S = 30.0
+_BORROW_WAIT_BACKOFF_S = 0.05
+
+
+def _resolvePoolMax() -> int:
+ """Pool size is configurable via `DB_POOL_MAX_CONN` (default 20)."""
+ try:
+ return max(2, int(APP_CONFIG.get("DB_POOL_MAX_CONN") or _DEFAULT_POOL_MAX))
+ except (ValueError, TypeError):
+ return _DEFAULT_POOL_MAX
+
+
+class _PoolRegistry:
+ """Process-wide registry of `ThreadedConnectionPool` instances.
+
+ Keyed by `(host, database, port)` so that two `DatabaseConnector` instances
+ pointing at the same physical database share one pool. Lazy-initialised and
+ thread-safe.
+ """
+
+ _pools: Dict[tuple, psycopg2.pool.ThreadedConnectionPool] = {}
+ _lock = threading.Lock()
+
+ @classmethod
+ def getPool(
+ cls,
+ *,
+ dbHost: str,
+ dbDatabase: str,
+ dbUser: str,
+ dbPassword: str,
+ dbPort: int,
+ ) -> psycopg2.pool.ThreadedConnectionPool:
+ port = int(dbPort) if dbPort is not None else 5432
+ key = (dbHost, dbDatabase, port)
+ # Fast path: pool exists.
+ pool = cls._pools.get(key)
+ if pool is not None:
+ return pool
+ # Slow path: create exactly one pool per key, even under contention.
+ with cls._lock:
+ pool = cls._pools.get(key)
+ if pool is not None:
+ return pool
+ poolMax = _resolvePoolMax()
+ options = f"-c statement_timeout={_STATEMENT_TIMEOUT_MS}"
+ try:
+ pool = psycopg2.pool.ThreadedConnectionPool(
+ _DEFAULT_POOL_MIN,
+ poolMax,
+ host=dbHost,
+ port=port,
+ database=dbDatabase,
+ user=dbUser,
+ password=dbPassword,
+ client_encoding="utf8",
+ cursor_factory=psycopg2.extras.RealDictCursor,
+ connect_timeout=_CONNECT_TIMEOUT_S,
+ options=options,
+ )
+ except Exception as e:
+ logger.error(
+ "Failed to create connection pool for db=%s host=%s: %s",
+ dbDatabase, dbHost, e,
+ )
+ raise
+ cls._pools[key] = pool
+ logger.debug(
+ "Created connection pool for db=%s host=%s port=%s (min=%d max=%d, stmt_timeout=%dms)",
+ dbDatabase, dbHost, port, _DEFAULT_POOL_MIN, poolMax, _STATEMENT_TIMEOUT_MS,
+ )
+ return pool
+
+ @classmethod
+ def closeAll(cls) -> None:
+ """Close every pool. Call once during FastAPI shutdown."""
+ with cls._lock:
+ for key, pool in list(cls._pools.items()):
+ try:
+ pool.closeall()
+ except Exception as e:
+ logger.warning("Error closing pool %s: %s", key, e)
+ cls._pools.clear()
+ logger.info("All database connection pools closed")
+
+
+def closeAllPools() -> None:
+ """Public entry point for FastAPI lifespan shutdown hook."""
+ _PoolRegistry.closeAll()
+
+
+class _ConnectionShim:
+ """Backward-compatibility stand-in for the old `DatabaseConnector.connection`.
+
+ Old callers used patterns like::
+
+ db.connection.commit()
+ if db.connection and not db.connection.closed: ...
+
+ These are now no-ops because the pool owns the connection lifecycle.
+ Direct cursor access through `self.connection.cursor()` is intentionally
+ blocked with a clear error so silent breakage is impossible — every such
+ call site must migrate to `db.borrowCursor()`.
+ """
+
+ closed = False
+
+ def __bool__(self) -> bool:
+ return True
+
+ def commit(self) -> None:
+ return
+
+ def rollback(self) -> None:
+ return
+
+ def close(self) -> None:
+ return
+
+ def cursor(self, *args, **kwargs):
+ raise RuntimeError(
+ "DatabaseConnector.connection.cursor() is no longer supported. "
+ "Use `db.borrowCursor()` (or `db.borrowConn()` for multi-statement "
+ "transactions) so the connection is borrowed from and returned to "
+ "the pool correctly."
+ )
+
+
+_CONNECTION_SHIM = _ConnectionShim()
+
+
+# ---------------------------------------------------------------------------
+# Connector cache (lightweight wrappers — actual connections live in the pool)
+# ---------------------------------------------------------------------------
+# Multiple call sites (`routeI18n`, `aiAuditLogger`, `mainBackgroundJobService`,
+# interfaces) ask for a connector via `getCachedConnector(...)` and expect to
+# get back the same object on subsequent calls. Now that real connection
+# multiplexing happens at the pool layer, the cache returns lightweight
+# `DatabaseConnector` wrappers — they hold no connection themselves, only the
+# DSN params and a reference to the shared pool.
_MAX_CACHED_CONNECTORS = 32
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
_connector_cache_order: List[tuple] = [] # FIFO order for eviction
@@ -223,22 +373,36 @@ def getCachedConnector(
dbPort: int = None,
userId: str = None,
) -> "DatabaseConnector":
- """Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits.
- Uses contextvars for userId so concurrent requests sharing the same connector get correct sysCreatedBy/sysModifiedBy.
+ """Return a cached `DatabaseConnector` wrapper for `(host, database, port)`.
+
+ Two layers of caching, both intentional:
+
+ 1. **Pool layer** (`_PoolRegistry`) — owns the actual psycopg2 connections.
+ Per `(host, db, port)` exactly one `ThreadedConnectionPool`, shared
+ across every wrapper that points at the same physical database. This
+ is what actually saves Postgres connection slots.
+
+ 2. **Wrapper layer** (this function) — caches the `DatabaseConnector`
+ Python object. Each wrapper triggers an `initDbSystem()` on first
+ instantiation (CREATE DATABASE if missing, CREATE TABLE _system,
+ pool warm-up). Caching the wrapper avoids paying that bootstrap cost
+ on every request and stops the log from filling with "PostgreSQL
+ database system initialized" lines.
+
+ `userId` is request-scoped via the `_current_user_id` contextvar so two
+ concurrent requests sharing the same cached wrapper still produce
+ correct `sysCreatedBy` / `sysModifiedBy` audit fields.
"""
port = int(dbPort) if dbPort is not None else 5432
key = (dbHost, dbDatabase, port)
with _connector_cache_lock:
if key not in _connector_cache:
- # Evict oldest if at capacity
+ # FIFO eviction. Connectors are now lightweight (no per-instance
+ # connection), so eviction is purely a memory bookkeeping concern;
+ # the underlying pool stays alive in `_PoolRegistry`.
while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
oldest_key = _connector_cache_order.pop(0)
- if oldest_key in _connector_cache:
- try:
- _connector_cache[oldest_key].close(forceClose=True)
- except Exception as e:
- logger.warning(f"Error closing evicted connector: {e}")
- del _connector_cache[oldest_key]
+ _connector_cache.pop(oldest_key, None)
_connector_cache[key] = DatabaseConnector(
dbHost=dbHost,
dbDatabase=dbDatabase,
@@ -282,34 +446,38 @@ class DatabaseConnector:
# Set userId (default to empty string if None)
self.userId = userId if userId is not None else ""
- # Initialize database system first (creates database if needed)
- self.connection = None
+ # No per-instance connection any more — real connections live in the
+ # shared `_PoolRegistry` pool. `_isCachedShared` is retained because
+ # `close(forceClose=False)` callers (interface __del__) still ask.
self._isCachedShared = False
- self.initDbSystem()
- # No caching needed with proper database - PostgreSQL handles performance
-
- # Thread safety
- self._lock = threading.Lock()
-
- # pgvector extension state
+ # pgvector extension state (cached per connector instance — cheap)
self._vectorExtensionEnabled = False
- # Initialize system table
+ # System table bootstrap: create database, system table, ensure metadata.
self._systemTableName = "_system"
+ self.initDbSystem()
self._initializeSystemTable()
def initDbSystem(self):
- """Initialize the database system - creates database and tables."""
- try:
- # Create database if it doesn't exist
- self._create_database_if_not_exists()
+ """Bootstrap the physical database and the `_system` metadata table.
- # Create tables
+ Uses a short-lived autocommit connection (NOT the pool) because
+ `CREATE DATABASE` cannot run inside a transaction block. Also warms up
+ the pool by acquiring it for this `(host, db, port)` once.
+ """
+ try:
+ self._create_database_if_not_exists()
self._create_tables()
- # Establish connection to the database
- self._connect()
+ # Warm the pool so the first request doesn't pay for socket setup.
+ _PoolRegistry.getPool(
+ dbHost=self.dbHost,
+ dbDatabase=self.dbDatabase,
+ dbUser=self.dbUser,
+ dbPassword=self.dbPassword,
+ dbPort=self.dbPort,
+ )
logger.debug(
"PostgreSQL database system initialized (db=%s, host=%s, port=%s)",
@@ -319,10 +487,121 @@ class DatabaseConnector:
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
raise
- def _create_database_if_not_exists(self):
- """Create the database if it doesn't exist."""
+ @property
+ def connection(self) -> "_ConnectionShim":
+ """Backward-compat shim — see `_ConnectionShim` docstring."""
+ return _CONNECTION_SHIM
+
+ def _ensure_connection(self) -> None:
+ """No-op for backward compatibility.
+
+ Previously this method tested-or-reconnected the per-instance socket.
+ With pooling, the `ThreadedConnectionPool` re-establishes broken
+ connections lazily on the next `getconn()`. Kept as a no-op so that
+ legacy call sites continue to compile.
+ """
+ return
+
+ @contextmanager
+ def borrowCursor(self):
+ """Borrow a cursor for one short SQL block.
+
+ Convenience wrapper around `borrowConn()` for the common pattern:
+
+ with db.borrowCursor() as cursor:
+ cursor.execute(...)
+ rows = cursor.fetchall()
+
+ Replaces the legacy `with db.connection.cursor() as cursor:` pattern.
+ Commit/rollback and pool return are handled automatically — callers
+ must NOT call `.commit()`/`.rollback()` on the cursor themselves.
+ """
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ yield cursor
+
+ @contextmanager
+ def borrowConn(self):
+ """Borrow a connection from the pool for the duration of the block.
+
+ Pool-exhaustion semantics: `ThreadedConnectionPool.getconn()` raises
+ `psycopg2.pool.PoolError` immediately when the pool is at its `maxconn`
+ limit — it does NOT block. We wrap that with a bounded busy-wait so
+ bursty workloads (e.g. 50 concurrent route calls against a max=20
+ pool) queue up instead of failing. If no connection becomes available
+ within `_BORROW_WAIT_TIMEOUT_S`, the `PoolError` is propagated so a
+ genuinely deadlocked pool surfaces as a real error.
+
+ On normal exit the current transaction is committed (idempotent for
+ read-only queries — leaves the connection in a clean state for the
+ next borrower). On exception the transaction is rolled back and the
+ exception propagates. The connection is **always** returned to the
+ pool, even when commit/rollback fails — otherwise a single bad query
+ would leak a slot and eventually exhaust the pool.
+ """
+ pool = _PoolRegistry.getPool(
+ dbHost=self.dbHost,
+ dbDatabase=self.dbDatabase,
+ dbUser=self.dbUser,
+ dbPassword=self.dbPassword,
+ dbPort=self.dbPort,
+ )
+ conn = self._acquireConn(pool)
+ try:
+ yield conn
+ except Exception:
+ try:
+ conn.rollback()
+ except Exception:
+ pass
+ raise
+ else:
+ # Best-effort commit so the connection goes back to the pool with
+ # no in-flight transaction. Failure here is non-fatal — the pool
+ # will re-establish the socket on the next `getconn()` if needed.
+ try:
+ conn.commit()
+ except Exception:
+ try:
+ conn.rollback()
+ except Exception:
+ pass
+ finally:
+ try:
+ pool.putconn(conn)
+ except Exception as e:
+ logger.warning("Failed to return connection to pool: %s", e)
+
+ @staticmethod
+ def _acquireConn(pool: psycopg2.pool.ThreadedConnectionPool):
+ """Get a connection from the pool, waiting up to `_BORROW_WAIT_TIMEOUT_S`.
+
+ psycopg2's pool throws on exhaustion instead of queueing — this helper
+ polls with a short backoff so callers see queue semantics.
+ """
+ deadline = time.monotonic() + _BORROW_WAIT_TIMEOUT_S
+ attempt = 0
+ while True:
+ try:
+ return pool.getconn()
+ except psycopg2.pool.PoolError as e:
+ attempt += 1
+ if time.monotonic() >= deadline:
+ logger.error(
+ "Connection pool exhausted after %.1fs wait (%d retries)",
+ _BORROW_WAIT_TIMEOUT_S, attempt,
+ )
+ raise
+ time.sleep(_BORROW_WAIT_BACKOFF_S)
+
+ def _create_database_if_not_exists(self):
+ """Create the database if it doesn't exist.
+
+ Uses an autocommit connection on the `postgres` admin DB because
+ `CREATE DATABASE` cannot run inside a transaction block — so this
+ path intentionally does NOT use the pool.
+ """
try:
- # Use the configured user for database creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
@@ -330,23 +609,21 @@ class DatabaseConnector:
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
+ connect_timeout=_CONNECT_TIMEOUT_S,
)
conn.autocommit = True
-
- with conn.cursor() as cursor:
- # Check if database exists
- cursor.execute(
- "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
- )
- exists = cursor.fetchone()
-
- if not exists:
- # Create database with proper quoting for names with hyphens
- quoted_db_name = f'"{self.dbDatabase}"'
- cursor.execute(f"CREATE DATABASE {quoted_db_name}")
- logger.info(f"Created database: {self.dbDatabase}")
-
- conn.close()
+ try:
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
+ )
+ exists = cursor.fetchone()
+ if not exists:
+ quoted_db_name = f'"{self.dbDatabase}"'
+ cursor.execute(f"CREATE DATABASE {quoted_db_name}")
+ logger.info(f"Created database: {self.dbDatabase}")
+ finally:
+ conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create database: {e}")
@@ -356,9 +633,12 @@ class DatabaseConnector:
)
def _create_tables(self):
- """Create only the system table - application tables are created by interfaces."""
+ """Create the `_system` table.
+
+ Uses a short-lived autocommit connection (not the pool) — runs exactly
+ once at connector creation.
+ """
try:
- # Use the configured user for table creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
@@ -366,23 +646,24 @@ class DatabaseConnector:
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
+ connect_timeout=_CONNECT_TIMEOUT_S,
)
conn.autocommit = True
-
- with conn.cursor() as cursor:
- # Create only the system table
- cursor.execute("""
- CREATE TABLE IF NOT EXISTS _system (
- id SERIAL PRIMARY KEY,
- table_name VARCHAR(255) UNIQUE NOT NULL,
- initial_id VARCHAR(255) NOT NULL,
- "sysCreatedAt" DOUBLE PRECISION,
- "sysCreatedBy" VARCHAR(255),
- "sysModifiedAt" DOUBLE PRECISION,
- "sysModifiedBy" VARCHAR(255)
- )
- """)
- conn.close()
+ try:
+ with conn.cursor() as cursor:
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS _system (
+ id SERIAL PRIMARY KEY,
+ table_name VARCHAR(255) UNIQUE NOT NULL,
+ initial_id VARCHAR(255) NOT NULL,
+ "sysCreatedAt" DOUBLE PRECISION,
+ "sysCreatedBy" VARCHAR(255),
+ "sysModifiedAt" DOUBLE PRECISION,
+ "sysModifiedBy" VARCHAR(255)
+ )
+ """)
+ finally:
+ conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create system table: {e}")
@@ -391,67 +672,26 @@ class DatabaseConnector:
)
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
- def _connect(self):
- """Establish connection to PostgreSQL database."""
- try:
- # Use configured user for main connection with proper parameter handling
- self.connection = psycopg2.connect(
- host=self.dbHost,
- port=self.dbPort,
- database=self.dbDatabase,
- user=self.dbUser,
- password=self.dbPassword,
- client_encoding="utf8",
- cursor_factory=psycopg2.extras.RealDictCursor,
- )
- self.connection.autocommit = False # Use transactions
- except Exception as e:
- logger.error(f"Failed to connect to PostgreSQL: {e}")
- raise
-
- def _ensure_connection(self):
- """Ensure database connection is alive, reconnect if necessary."""
- try:
- if self.connection is None or self.connection.closed:
- self._connect()
- else:
- # Test connection with a simple query
- with self.connection.cursor() as cursor:
- cursor.execute("SELECT 1")
- except Exception as e:
- logger.warning(f"Connection lost, reconnecting: {e}")
- self._connect()
-
def _initializeSystemTable(self):
"""Initializes the system table if it doesn't exist yet."""
try:
- # First ensure the system table exists
self._ensureTableExists(SystemTable)
-
- with self.connection.cursor() as cursor:
- # Check if system table has any data
- cursor.execute('SELECT COUNT(*) FROM "_system"')
- row = cursor.fetchone()
- count = row["count"] if row else 0
-
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute('SELECT COUNT(*) FROM "_system"')
+ cursor.fetchone() # noqa: just verifies table is readable
except Exception as e:
logger.error(f"Error initializing system table: {e}")
- self.connection.rollback()
raise
def _loadSystemTable(self) -> Dict[str, str]:
"""Loads the system table with the initial IDs."""
try:
- with self.connection.cursor() as cursor:
- cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
- rows = cursor.fetchall()
-
- system_data = {}
- for row in rows:
- system_data[row["table_name"]] = row["initial_id"]
-
- return system_data
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
+ rows = cursor.fetchall()
+ return {row["table_name"]: row["initial_id"] for row in rows}
except Exception as e:
logger.error(f"Error loading system table: {e}")
return {}
@@ -459,75 +699,65 @@ class DatabaseConnector:
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
"""Saves the system table with the initial IDs."""
try:
- with self.connection.cursor() as cursor:
- # Clear existing data
- cursor.execute('DELETE FROM "_system"')
-
- # Insert new data
- for table_name, initial_id in data.items():
- cursor.execute(
- """
- INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt")
- VALUES (%s, %s, %s)
- """,
- (table_name, initial_id, getUtcTimestamp()),
- )
-
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute('DELETE FROM "_system"')
+ for table_name, initial_id in data.items():
+ cursor.execute(
+ """
+ INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt")
+ VALUES (%s, %s, %s)
+ """,
+ (table_name, initial_id, getUtcTimestamp()),
+ )
return True
except Exception as e:
logger.error(f"Error saving system table: {e}")
- self.connection.rollback()
return False
def _ensureSystemTableExists(self) -> bool:
"""Ensures the system table exists, creates it if it doesn't."""
try:
- self._ensure_connection()
-
- with self.connection.cursor() as cursor:
- # Check if system table exists
- cursor.execute(
- "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
- (self._systemTableName,),
- )
- exists = cursor.fetchone()["count"] > 0
-
- if not exists:
- # Create system table
- cursor.execute(f"""
- CREATE TABLE "{self._systemTableName}" (
- "table_name" VARCHAR(255) PRIMARY KEY,
- "initial_id" VARCHAR(255),
- "sysCreatedAt" DOUBLE PRECISION,
- "sysCreatedBy" VARCHAR(255),
- "sysModifiedAt" DOUBLE PRECISION,
- "sysModifiedBy" VARCHAR(255)
- )
- """)
- logger.info("System table created successfully")
- else:
- # Check if we need to add missing columns to existing table
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
cursor.execute(
- """
- SELECT column_name FROM information_schema.columns
- WHERE table_name = %s AND table_schema = 'public'
- """,
+ "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
(self._systemTableName,),
)
- existing_columns = [row["column_name"] for row in cursor.fetchall()]
+ exists = cursor.fetchone()["count"] > 0
- for sys_col, sys_sql in [
- ("sysCreatedAt", "DOUBLE PRECISION"),
- ("sysCreatedBy", "VARCHAR(255)"),
- ("sysModifiedAt", "DOUBLE PRECISION"),
- ("sysModifiedBy", "VARCHAR(255)"),
- ]:
- if sys_col not in existing_columns:
- cursor.execute(
- f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
+ if not exists:
+ cursor.execute(f"""
+ CREATE TABLE "{self._systemTableName}" (
+ "table_name" VARCHAR(255) PRIMARY KEY,
+ "initial_id" VARCHAR(255),
+ "sysCreatedAt" DOUBLE PRECISION,
+ "sysCreatedBy" VARCHAR(255),
+ "sysModifiedAt" DOUBLE PRECISION,
+ "sysModifiedBy" VARCHAR(255)
)
+ """)
+ logger.info("System table created successfully")
+ else:
+ cursor.execute(
+ """
+ SELECT column_name FROM information_schema.columns
+ WHERE table_name = %s AND table_schema = 'public'
+ """,
+ (self._systemTableName,),
+ )
+ existing_columns = [row["column_name"] for row in cursor.fetchall()]
+ for sys_col, sys_sql in [
+ ("sysCreatedAt", "DOUBLE PRECISION"),
+ ("sysCreatedBy", "VARCHAR(255)"),
+ ("sysModifiedAt", "DOUBLE PRECISION"),
+ ("sysModifiedBy", "VARCHAR(255)"),
+ ]:
+ if sys_col not in existing_columns:
+ cursor.execute(
+ f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
+ )
return True
except Exception as e:
logger.error(f"Error ensuring system table exists: {e}")
@@ -542,123 +772,113 @@ class DatabaseConnector:
return self._ensureSystemTableExists()
try:
- self._ensure_connection()
-
- with self.connection.cursor() as cursor:
- # Check if table exists by querying information_schema with case-insensitive search
- cursor.execute(
- """
- SELECT COUNT(*) FROM information_schema.tables
- WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
- """,
- (table,),
- )
- exists = cursor.fetchone()["count"] > 0
-
- if not exists:
- # Create table from Pydantic model
- self._create_table_from_model(cursor, table, model_class)
- logger.info(
- f"Created table '{table}' with columns from Pydantic model"
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(
+ """
+ SELECT COUNT(*) FROM information_schema.tables
+ WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
+ """,
+ (table,),
)
- else:
- # Table exists: ensure all columns from model are present (simple additive migration)
- try:
- cursor.execute(
- """
- SELECT column_name, data_type
- FROM information_schema.columns
- WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
- """,
- (table,),
+ exists = cursor.fetchone()["count"] > 0
+
+ if not exists:
+ self._create_table_from_model(cursor, table, model_class)
+ logger.info(
+ f"Created table '{table}' with columns from Pydantic model"
)
- existing_column_rows = cursor.fetchall()
- existing_columns = {
- row["column_name"] for row in existing_column_rows
- }
- existing_column_types = {
- row["column_name"]: (row["data_type"] or "").lower()
- for row in existing_column_rows
- }
+ else:
+ # Table exists: ensure all columns from model are present (simple additive migration)
+ try:
+ cursor.execute(
+ """
+ SELECT column_name, data_type
+ FROM information_schema.columns
+ WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
+ """,
+ (table,),
+ )
+ existing_column_rows = cursor.fetchall()
+ existing_columns = {
+ row["column_name"] for row in existing_column_rows
+ }
+ existing_column_types = {
+ row["column_name"]: (row["data_type"] or "").lower()
+ for row in existing_column_rows
+ }
- # Desired columns based on model
- model_fields = getModelFields(model_class)
- desired_columns = set(["id"]) | set(model_fields.keys())
+ model_fields = getModelFields(model_class)
+ desired_columns = set(["id"]) | set(model_fields.keys())
- # Add missing columns
- for col in sorted(desired_columns - existing_columns):
- # Determine SQL type
- if col in ["id"]:
- continue # primary key exists already
- sql_type = model_fields.get(col)
- if not sql_type:
- sql_type = "TEXT"
- try:
- cursor.execute(
- f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}'
- )
- logger.info(
- f"Added missing column '{col}' ({sql_type}) to '{table}'"
- )
- except Exception as add_err:
- logger.warning(
- f"Could not add column '{col}' to '{table}': {add_err}"
- )
-
- # Column type migrations for existing tables.
- # TEXT→DOUBLE PRECISION handles three value shapes:
- # 1. NULL / empty string → NULL
- # 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT
- # 3. Plain numeric string like "3.14" → direct cast
- _TEXT_TO_DOUBLE = (
- 'DOUBLE PRECISION USING CASE'
- ' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL'
- ' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\''
- ' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)'
- ' ELSE NULLIF("{col}", \'\')::double precision'
- ' END'
- )
- _SAFE_TYPE_CHANGES = {
- ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text",
- ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE,
- ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer",
- ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')',
- ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")',
- ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')',
- }
- for col in sorted(desired_columns & existing_columns):
- if col == "id":
- continue
- desired_sql = (model_fields.get(col) or "").upper()
- currentType = existing_column_types.get(col, "")
- migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql))
- if migration:
- castExpr = migration.replace("{col}", col)
+ for col in sorted(desired_columns - existing_columns):
+ if col in ["id"]:
+ continue
+ sql_type = model_fields.get(col)
+ if not sql_type:
+ sql_type = "TEXT"
try:
- cursor.execute('SAVEPOINT col_migrate')
cursor.execute(
- f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}'
+ f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}'
)
- cursor.execute('RELEASE SAVEPOINT col_migrate')
logger.info(
- f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'"
+ f"Added missing column '{col}' ({sql_type}) to '{table}'"
)
- except Exception as alter_err:
- cursor.execute('ROLLBACK TO SAVEPOINT col_migrate')
+ except Exception as add_err:
logger.warning(
- f"Could not migrate column '{col}' on '{table}': {alter_err}"
+ f"Could not add column '{col}' to '{table}': {add_err}"
)
- except Exception as ensure_err:
- logger.warning(
- f"Could not ensure columns for existing table '{table}': {ensure_err}"
- )
- self.connection.commit()
+ # Column type migrations for existing tables.
+ # TEXT→DOUBLE PRECISION handles three value shapes:
+ # 1. NULL / empty string → NULL
+ # 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT
+ # 3. Plain numeric string like "3.14" → direct cast
+ _TEXT_TO_DOUBLE = (
+ 'DOUBLE PRECISION USING CASE'
+ ' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL'
+ ' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\''
+ ' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)'
+ ' ELSE NULLIF("{col}", \'\')::double precision'
+ ' END'
+ )
+ _SAFE_TYPE_CHANGES = {
+ ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text",
+ ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE,
+ ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer",
+ ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')',
+ ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")',
+ ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')',
+ }
+ for col in sorted(desired_columns & existing_columns):
+ if col == "id":
+ continue
+ desired_sql = (model_fields.get(col) or "").upper()
+ currentType = existing_column_types.get(col, "")
+ migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql))
+ if migration:
+ castExpr = migration.replace("{col}", col)
+ try:
+ cursor.execute('SAVEPOINT col_migrate')
+ cursor.execute(
+ f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}'
+ )
+ cursor.execute('RELEASE SAVEPOINT col_migrate')
+ logger.info(
+ f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'"
+ )
+ except Exception as alter_err:
+ cursor.execute('ROLLBACK TO SAVEPOINT col_migrate')
+ logger.warning(
+ f"Could not migrate column '{col}' on '{table}': {alter_err}"
+ )
+ except Exception as ensure_err:
+ logger.warning(
+ f"Could not ensure columns for existing table '{table}': {ensure_err}"
+ )
return True
except Exception as e:
logger.error(f"Error ensuring table {table} exists: {e}")
- if hasattr(self, "connection") and self.connection:
- self.connection.rollback()
return False
def _ensureVectorExtension(self) -> bool:
@@ -666,17 +886,14 @@ class DatabaseConnector:
if self._vectorExtensionEnabled:
return True
try:
- self._ensure_connection()
- with self.connection.cursor() as cursor:
- cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
self._vectorExtensionEnabled = True
logger.info("pgvector extension enabled")
return True
except Exception as e:
logger.error(f"Failed to enable pgvector extension: {e}")
- if hasattr(self, "connection") and self.connection:
- self.connection.rollback()
return False
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
@@ -791,22 +1008,19 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class):
return None
- with self.connection.cursor() as cursor:
- cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
- row = cursor.fetchone()
- if not row:
- return None
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
+ row = cursor.fetchone()
+ if not row:
+ return None
+ record = dict(row)
- # Convert row to dict and handle JSONB fields
- record = dict(row)
- fields = getModelFields(model_class)
-
- parseRecordFields(record, fields, f"record {recordId}")
-
- return record
+ fields = getModelFields(model_class)
+ parseRecordFields(record, fields, f"record {recordId}")
+ return record
except Exception as e:
logger.error(f"Error loading record {recordId} from table {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
@@ -849,14 +1063,12 @@ class DatabaseConnector:
if effective_user_id:
record["sysModifiedBy"] = effective_user_id
- with self.connection.cursor() as cursor:
- self._save_record(cursor, table, recordId, record, model_class)
-
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ self._save_record(cursor, table, recordId, record, model_class)
return True
except Exception as e:
logger.error(f"Error saving record {recordId} to table {table}: {e}")
- self.connection.rollback()
return False
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
@@ -870,33 +1082,32 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class):
return []
- with self.connection.cursor() as cursor:
- cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
- records = [dict(row) for row in cursor.fetchall()]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
+ records = [dict(row) for row in cursor.fetchall()]
- fields = getModelFields(model_class)
- modelFields = model_class.model_fields
- for record in records:
- parseRecordFields(record, fields, f"table {table}")
- # Set type-aware defaults for NULL JSONB fields
- for fieldName, fieldType in fields.items():
- if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
- fieldInfo = modelFields.get(fieldName)
- if fieldInfo:
- fieldAnnotation = fieldInfo.annotation
- if (fieldAnnotation == list or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is list)):
- record[fieldName] = []
- elif (fieldAnnotation == dict or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is dict)):
- record[fieldName] = {}
+ fields = getModelFields(model_class)
+ modelFields = model_class.model_fields
+ for record in records:
+ parseRecordFields(record, fields, f"table {table}")
+ for fieldName, fieldType in fields.items():
+ if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+ fieldInfo = modelFields.get(fieldName)
+ if fieldInfo:
+ fieldAnnotation = fieldInfo.annotation
+ if (fieldAnnotation == list or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is list)):
+ record[fieldName] = []
+ elif (fieldAnnotation == dict or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is dict)):
+ record[fieldName] = {}
- return records
+ return records
except Exception as e:
logger.error(f"Error loading table {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def _registerInitialId(self, table: str, initialId: str) -> bool:
@@ -969,28 +1180,20 @@ class DatabaseConnector:
def getTables(self) -> List[str]:
"""Returns a list of all available tables."""
- tables = []
-
+ tables: List[str] = []
try:
- # Ensure connection is alive
- self._ensure_connection()
-
- if not self.connection or self.connection.closed:
- logger.error("Database connection is not available")
- return tables
-
- with self.connection.cursor() as cursor:
- cursor.execute("""
- SELECT table_name
- FROM information_schema.tables
- WHERE table_schema = 'public'
- ORDER BY table_name
- """)
- rows = cursor.fetchall()
- tables = [row["table_name"] for row in rows]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute("""
+ SELECT table_name
+ FROM information_schema.tables
+ WHERE table_schema = 'public'
+ ORDER BY table_name
+ """)
+ rows = cursor.fetchall()
+ tables = [row["table_name"] for row in rows]
except Exception as e:
logger.error(f"Error reading the database {self.dbDatabase}: {e}")
-
return tables
def getFields(self, model_class: type) -> List[str]:
@@ -1060,43 +1263,42 @@ class DatabaseConnector:
query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
- with self.connection.cursor() as cursor:
- cursor.execute(query, where_values)
- records = [dict(row) for row in cursor.fetchall()]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(query, where_values)
+ records = [dict(row) for row in cursor.fetchall()]
- fields = getModelFields(model_class)
- modelFields = model_class.model_fields
+ fields = getModelFields(model_class)
+ modelFields = model_class.model_fields
+ for record in records:
+ parseRecordFields(record, fields, f"table {table}")
+ for fieldName, fieldType in fields.items():
+ if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+ fieldInfo = modelFields.get(fieldName)
+ if fieldInfo:
+ fieldAnnotation = fieldInfo.annotation
+ if (fieldAnnotation == list or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is list)):
+ record[fieldName] = []
+ elif (fieldAnnotation == dict or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is dict)):
+ record[fieldName] = {}
+
+ if fieldFilter and isinstance(fieldFilter, list):
+ result = []
for record in records:
- parseRecordFields(record, fields, f"table {table}")
- for fieldName, fieldType in fields.items():
- if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
- fieldInfo = modelFields.get(fieldName)
- if fieldInfo:
- fieldAnnotation = fieldInfo.annotation
- if (fieldAnnotation == list or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is list)):
- record[fieldName] = []
- elif (fieldAnnotation == dict or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is dict)):
- record[fieldName] = {}
+ filteredRecord = {}
+ for field in fieldFilter:
+ if field in record:
+ filteredRecord[field] = record[field]
+ result.append(filteredRecord)
+ return result
- # If fieldFilter is available, reduce the fields
- if fieldFilter and isinstance(fieldFilter, list):
- result = []
- for record in records:
- filteredRecord = {}
- for field in fieldFilter:
- if field in record:
- filteredRecord[field] = record[field]
- result.append(filteredRecord)
- return result
-
- return records
+ return records
except Exception as e:
logger.error(f"Error loading records from table {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def _buildPaginationClauses(
@@ -1281,35 +1483,36 @@ class DatabaseConnector:
where_clause, order_clause, limit_clause, values, count_values = \
self._buildPaginationClauses(model_class, pagination, recordFilter)
- with self.connection.cursor() as cursor:
- countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
- dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
- cursor.execute(countSql, count_values)
- totalItems = cursor.fetchone()["count"]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
+ dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
+ cursor.execute(countSql, count_values)
+ totalItems = cursor.fetchone()["count"]
- cursor.execute(dataSql, values)
- records = [dict(row) for row in cursor.fetchall()]
+ cursor.execute(dataSql, values)
+ records = [dict(row) for row in cursor.fetchall()]
- fields = getModelFields(model_class)
- modelFields = model_class.model_fields
- for record in records:
- parseRecordFields(record, fields, f"table {table}")
- for fieldName, fieldType in fields.items():
- if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
- fieldInfo = modelFields.get(fieldName)
- if fieldInfo:
- fieldAnnotation = fieldInfo.annotation
- if (fieldAnnotation == list or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is list)):
- record[fieldName] = []
- elif (fieldAnnotation == dict or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is dict)):
- record[fieldName] = {}
+ fields = getModelFields(model_class)
+ modelFields = model_class.model_fields
+ for record in records:
+ parseRecordFields(record, fields, f"table {table}")
+ for fieldName, fieldType in fields.items():
+ if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+ fieldInfo = modelFields.get(fieldName)
+ if fieldInfo:
+ fieldAnnotation = fieldInfo.annotation
+ if (fieldAnnotation == list or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is list)):
+ record[fieldName] = []
+ elif (fieldAnnotation == dict or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is dict)):
+ record[fieldName] = {}
- if fieldFilter and isinstance(fieldFilter, list):
- records = [{f: r[f] for f in fieldFilter if f in r} for r in records]
+ if fieldFilter and isinstance(fieldFilter, list):
+ records = [{f: r[f] for f in fieldFilter if f in r} for r in records]
from modules.routes.routeHelpers import enrichRowsWithFkLabels
enrichRowsWithFkLabels(records, model_class)
@@ -1320,7 +1523,6 @@ class DatabaseConnector:
return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
except Exception as e:
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def getDistinctColumnValues(
@@ -1365,25 +1567,24 @@ class DatabaseConnector:
else:
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val'
- with self.connection.cursor() as cursor:
- cursor.execute(sql, values)
- result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(sql, values)
+ result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
- if includeEmpty:
- emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\''
- if where_clause:
- emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
- else:
- emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
- with self.connection.cursor() as cursor:
- cursor.execute(emptySql, values)
- if cursor.fetchone():
- result.append(None)
+ if includeEmpty:
+ emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\''
+ if where_clause:
+ emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
+ else:
+ emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
+ cursor.execute(emptySql, values)
+ if cursor.fetchone():
+ result.append(None)
return result
except Exception as e:
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def recordCreate(
@@ -1463,33 +1664,33 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class):
return False
- with self.connection.cursor() as cursor:
- # Check if record exists
- cursor.execute(
- f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
- )
- if not cursor.fetchone():
- return False
+ # `getInitialId` opens its own borrow; do it BEFORE we acquire a
+ # connection ourselves so we don't pin two slots concurrently.
+ initialId = self.getInitialId(model_class)
- # Check if it's an initial record
- initialId = self.getInitialId(model_class)
- if initialId is not None and initialId == recordId:
- self._removeInitialId(table)
- logger.info(
- f"Initial ID {recordId} for table {table} has been removed from the system table"
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(
+ f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
)
+ if not cursor.fetchone():
+ return False
- # Delete the record
- cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
+ if initialId is not None and initialId == recordId:
+ # `_removeInitialId` borrows its own conn — done outside
+ # this block on purpose to avoid nested borrows.
+ pass
+ cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
- # No cache to update - database handles consistency
-
- self.connection.commit()
+ if initialId is not None and initialId == recordId:
+ self._removeInitialId(table)
+ logger.info(
+ f"Initial ID {recordId} for table {table} has been removed from the system table"
+ )
return True
except Exception as e:
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
- self.connection.rollback()
return False
def recordCreateBulk(
@@ -1559,16 +1760,11 @@ class DatabaseConnector:
)
try:
- self._ensure_connection()
- with self.connection.cursor() as cursor:
- psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
except Exception as e:
logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}")
- try:
- self.connection.rollback()
- except Exception:
- pass
raise
if self.getInitialId(model_class) is None and normalised:
@@ -1649,26 +1845,21 @@ class DatabaseConnector:
initialId = self.getInitialId(model_class)
try:
- self._ensure_connection()
- with self.connection.cursor() as cursor:
- if initialId is not None:
- cursor.execute(
- f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
- [initialId, *params],
- )
- initialIsAffected = cursor.fetchone() is not None
- else:
- initialIsAffected = False
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ if initialId is not None:
+ cursor.execute(
+ f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
+ [initialId, *params],
+ )
+ initialIsAffected = cursor.fetchone() is not None
+ else:
+ initialIsAffected = False
- cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
- deleted = cursor.rowcount or 0
- self.connection.commit()
+ cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
+ deleted = cursor.rowcount or 0
except Exception as e:
logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}")
- try:
- self.connection.rollback()
- except Exception:
- pass
raise
if deleted and initialIsAffected:
@@ -1751,39 +1942,30 @@ class DatabaseConnector:
)
params = [vectorStr] + whereValues + [vectorStr, limit]
- with self.connection.cursor() as cursor:
- cursor.execute(query, params)
- records = [dict(row) for row in cursor.fetchall()]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(query, params)
+ records = [dict(row) for row in cursor.fetchall()]
- fields = getModelFields(modelClass)
- for record in records:
- parseRecordFields(record, fields, f"semanticSearch {table}")
-
- return records
+ fields = getModelFields(modelClass)
+ for record in records:
+ parseRecordFields(record, fields, f"semanticSearch {table}")
+ return records
except Exception as e:
logger.error(f"Error in semantic search on {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def close(self, forceClose: bool = False):
- """Close the database connection.
-
- Shared cached connectors are intentionally kept open unless forceClose=True.
- This prevents accidental shutdown from interface __del__ methods while
- other requests are still using the same cached connector instance.
+ """No-op for backward compatibility.
+
+ Connections are now owned by the `_PoolRegistry` pool and live for the
+ process lifetime. Pool shutdown happens centrally via `closeAllPools()`
+ from the FastAPI lifespan hook — never from a connector instance.
+ Interface `__del__` paths used to call `close()` to release a per-
+ connector socket; with pooling there is nothing to close here.
"""
- if self._isCachedShared and not forceClose:
- return
- if (
- hasattr(self, "connection")
- and self.connection
- and not self.connection.closed
- ):
- self.connection.close()
+ return
def __del__(self):
- """Cleanup method to close connection."""
- try:
- self.close()
- except Exception:
- pass
+ """Cleanup hook (intentionally no-op — see `close`)."""
+ return
diff --git a/modules/features/realEstate/interfaceFeatureRealEstate.py b/modules/features/realEstate/interfaceFeatureRealEstate.py
index 1fbaf06f..0637d0e9 100644
--- a/modules/features/realEstate/interfaceFeatureRealEstate.py
+++ b/modules/features/realEstate/interfaceFeatureRealEstate.py
@@ -342,7 +342,7 @@ class RealEstateObjects:
# If no exact match, try case-insensitive search via SQL query
# This handles cases where the name might have different casing
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@@ -375,7 +375,7 @@ class RealEstateObjects:
# Try case-insensitive search
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@@ -408,7 +408,7 @@ class RealEstateObjects:
# Try case-insensitive search
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@@ -840,7 +840,7 @@ class RealEstateObjects:
# Ensure connection is alive
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
# Execute query
if parameters:
# Use parameterized query for safety
diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py
index 25f022af..273583d9 100644
--- a/modules/interfaces/interfaceDbBilling.py
+++ b/modules/interfaces/interfaceDbBilling.py
@@ -1659,7 +1659,7 @@ class BillingObjects:
try:
appInterface = getAppInterface(self.currentUser)
appInterface.db._ensure_connection()
- with appInterface.db.connection.cursor() as cur:
+ with appInterface.db.borrowCursor() as cur:
if appInterface.db._ensureTableExists(UserInDB):
cur.execute(
'SELECT "id" FROM "UserInDB" WHERE '
@@ -1780,7 +1780,7 @@ class BillingObjects:
try:
self.db._ensure_connection()
- with self.db.connection.cursor() as cur:
+ with self.db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cur.execute(countSql, whereValues)
totalItems = cur.fetchone()["count"]
@@ -1797,10 +1797,7 @@ class BillingObjects:
except Exception as e:
logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
- try:
- self.db.connection.rollback()
- except Exception:
- pass
+ # Rollback is handled by `borrowCursor()` context manager on exit.
return {"items": [], "totalItems": 0, "totalPages": 0}
def _buildScopeFilter(
@@ -1872,7 +1869,7 @@ class BillingObjects:
result: Dict[str, Any] = {}
- with self.db.connection.cursor() as cur:
+ with self.db.borrowCursor() as cur:
# 1) Totals
cur.execute(
f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
@@ -1947,17 +1944,12 @@ class BillingObjects:
})
result["timeSeries"] = timeSeries
- self.db.connection.commit()
-
+ # Commit/rollback are handled by `borrowCursor()` context manager.
result["_allAccounts"] = allAccounts
return result
except Exception as e:
logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
- try:
- self.db.connection.rollback()
- except Exception:
- pass
return self._emptyStats()
@staticmethod
diff --git a/modules/interfaces/interfaceDbKnowledge.py b/modules/interfaces/interfaceDbKnowledge.py
index 31a5af61..d7a445bd 100644
--- a/modules/interfaces/interfaceDbKnowledge.py
+++ b/modules/interfaces/interfaceDbKnowledge.py
@@ -228,6 +228,22 @@ class KnowledgeObjects:
"""Get all ContentChunks for a file."""
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
+ def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]:
+ """Return a {fileId: chunkCount} mapping for the given file IDs.
+
+ One aggregate query instead of N round trips. Used by RAG inventory
+ to display real chunk counts per DataSource without loading the
+ embedding vectors. Missing file IDs map to 0 in the caller's logic.
+ """
+ if not fileIds:
+ return {}
+ if not self.db._ensureTableExists(ContentChunk):
+ return {}
+ sql = 'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"'
+ with self.db.borrowCursor() as cursor:
+ cursor.execute(sql, (list(fileIds),))
+ return {row["fileId"]: int(row["cnt"]) for row in cursor.fetchall()}
+
def deleteContentChunks(self, fileId: str) -> int:
"""Delete all ContentChunks for a file. Returns count of deleted chunks."""
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py
index 6a3c27b5..4dc8a206 100644
--- a/modules/interfaces/interfaceDbManagement.py
+++ b/modules/interfaces/interfaceDbManagement.py
@@ -1221,22 +1221,17 @@ class ComponentObjects:
for item in fileRows
]
- # Single transaction: delete FileData, FileItem, then FileFolder (children first)
- self.db._ensure_connection()
- try:
- with self.db.connection.cursor() as cursor:
- if fileIds:
- cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
- cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
- orderedIds = list(folderIds)
- orderedIds.remove(folderId)
- orderedIds.append(folderId)
- if orderedIds:
- cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
- self.db.connection.commit()
- except Exception:
- self.db.connection.rollback()
- raise
+ # Single transaction: delete FileData, FileItem, then FileFolder (children first).
+ # Commit/rollback are handled by `borrowCursor()` on exit.
+ with self.db.borrowCursor() as cursor:
+ if fileIds:
+ cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
+ cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
+ orderedIds = list(folderIds)
+ orderedIds.remove(folderId)
+ orderedIds.append(folderId)
+ if orderedIds:
+ cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)}
@@ -1507,7 +1502,7 @@ class ComponentObjects:
try:
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
(uniqueIds,),
@@ -1526,11 +1521,10 @@ class ComponentObjects:
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
deletedFiles = cursor.rowcount
- self.db.connection.commit()
+ # Commit/rollback are handled by `borrowCursor()` context manager.
return {"deletedFiles": deletedFiles}
except Exception as e:
logger.error(f"Error deleting files in batch: {e}")
- self.db.connection.rollback()
raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]:
diff --git a/modules/interfaces/interfaceRbac.py b/modules/interfaces/interfaceRbac.py
index e41485e0..948609ef 100644
--- a/modules/interfaces/interfaceRbac.py
+++ b/modules/interfaces/interfaceRbac.py
@@ -374,7 +374,7 @@ def getRecordsetWithRBAC(
query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
cursor.execute(query, whereValues)
records = [dict(row) for row in cursor.fetchall()]
@@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC(
offset = (pagination.page - 1) * pagination.pageSize
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cursor.execute(countSql, countValues)
totalItems = cursor.fetchone()["count"]
@@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC(
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val'
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
cursor.execute(sql, whereValues)
result = [row["val"] for row in cursor.fetchall()]
@@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC(
emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1'
else:
emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1'
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
cursor.execute(emptySql, whereValues)
if cursor.fetchone():
result.append(None)
@@ -967,7 +967,7 @@ def buildRbacWhereClause(
# Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate
if table == "UserInDB":
try:
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate
cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
@@ -994,7 +994,7 @@ def buildRbacWhereClause(
# For UserConnection: Filter via UserMandate junction table
elif table == "UserConnection":
try:
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate
cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
diff --git a/modules/routes/routeHelpers.py b/modules/routes/routeHelpers.py
index f1d88e31..bb1386af 100644
--- a/modules/routes/routeHelpers.py
+++ b/modules/routes/routeHelpers.py
@@ -305,7 +305,7 @@ def handleIdsMode(
sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
- with db.connection.cursor() as cursor:
+ with db.borrowCursor() as cursor:
cursor.execute(sql, values)
return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
except Exception as e:
diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py
index 074b5b85..7c426d77 100644
--- a/modules/routes/routeRagInventory.py
+++ b/modules/routes/routeRagInventory.py
@@ -25,6 +25,18 @@ router = APIRouter(
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
+ """Build per-connection RAG inventory rows.
+
+ Each DataSource row exposes BOTH numbers because they mean different things:
+ * `fileCount` — distinct files indexed (== `FileContentIndex` rows)
+ * `chunkCount` — embedding-sized text fragments (== `ContentChunk` rows,
+ max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval
+ actually hits)
+
+ A single PDF typically yields 1 file × 5–100 chunks; legacy UI labelled
+ `len(FileContentIndex)` as "chunks" which was off by 1–2 orders of
+ magnitude and misleading.
+ """
from modules.datamodels.datamodelDataSource import DataSource
from modules.datamodels.datamodelKnowledge import FileContentIndex
@@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
- connChunkTotal = len(connIndexRows)
+ connFileTotal = len(connIndexRows)
+ # Map fileId → real chunk count via 1 aggregate query (cheap even for
+ # connections with thousands of files; we never load the vector body).
+ fileIds = [
+ (idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", ""))
+ for idx in connIndexRows
+ ]
+ fileIds = [fid for fid in fileIds if fid]
+ chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
+ connChunkTotal = sum(chunkCountByFile.values())
+
+ filesByDs: Dict[str, int] = {}
chunksByDs: Dict[str, int] = {}
- unassigned = 0
+ unassignedFiles = 0
+ unassignedChunks = 0
for idx in connIndexRows:
+ fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")
+ chunkCnt = chunkCountByFile.get(fileId, 0)
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
if dsIdRef:
- chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
+ filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1
+ chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt
else:
- unassigned += 1
+ unassignedFiles += 1
+ unassignedChunks += chunkCnt
seen: Dict[str, bool] = {}
dsItems = []
@@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
+ "fileCount": filesByDs.get(dsId, 0),
"chunkCount": chunksByDs.get(dsId, 0),
})
- if unassigned > 0 and len(dsItems) > 0:
- perDs = unassigned // len(dsItems)
- remainder = unassigned % len(dsItems)
+ # Spread orphan files (provenance lost) evenly so totals match.
+ if unassignedFiles > 0 and len(dsItems) > 0:
+ perFile = unassignedFiles // len(dsItems)
+ remFile = unassignedFiles % len(dsItems)
+ perChunk = unassignedChunks // len(dsItems)
+ remChunk = unassignedChunks % len(dsItems)
for i, item in enumerate(dsItems):
- item["chunkCount"] += perDs + (1 if i < remainder else 0)
+ item["fileCount"] += perFile + (1 if i < remFile else 0)
+ item["chunkCount"] += perChunk + (1 if i < remChunk else 0)
# Pull a wider window than the previous 5 so the "last successful
# sync" is found even if a connection has many recent jobs queued.
@@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"skippedPolicy": result.get("skippedPolicy", 0),
"failed": result.get("failed", 0),
"durationMs": result.get("durationMs", 0),
+ # Surface limit-stop reason so the UI can warn the user
+ # that the index is provably incomplete (and which budget
+ # to raise). None means the walker finished naturally.
+ "stoppedAtLimit": result.get("stoppedAtLimit"),
+ "limits": result.get("limits") or {},
+ "bytesProcessed": result.get("bytesProcessed", 0),
}
if lastError and lastSuccess:
break
@@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
"preferences": getattr(conn, "knowledgePreferences", None) or {},
"dataSources": dsItems,
+ "totalFiles": connFileTotal,
"totalChunks": connChunkTotal,
"runningJobs": runningJobs,
"lastError": lastError,
@@ -139,8 +179,9 @@ def _getInventoryMe(
items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
+ totalFiles = sum(c.get("totalFiles", 0) for c in items)
- return {"connections": items, "totals": {"chunks": totalChunks}}
+ return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except Exception as e:
logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@@ -170,9 +211,10 @@ def _getInventoryMandate(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
+ totalFiles = sum(c.get("totalFiles", 0) for c in items)
totalBytes = aggregateMandateRagTotalBytes(mandateId)
- return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}}
+ return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
except HTTPException:
raise
except Exception as e:
@@ -202,8 +244,9 @@ def _getInventoryPlatform(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
+ totalFiles = sum(c.get("totalFiles", 0) for c in items)
- return {"connections": items, "totals": {"chunks": totalChunks}}
+ return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except HTTPException:
raise
except Exception as e:
diff --git a/modules/routes/routeWorkflowDashboard.py b/modules/routes/routeWorkflowDashboard.py
index d83ce1b2..85b372a1 100644
--- a/modules/routes/routeWorkflowDashboard.py
+++ b/modules/routes/routeWorkflowDashboard.py
@@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s)
GROUP BY "workflowId"
"""
out: dict = {}
- with db.connection.cursor() as cursor:
+ with db.borrowCursor() as cursor:
cursor.execute(sql, (workflowIds,))
for row in cursor.fetchall():
r = dict(row)
@@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated(
dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}"
db._ensure_connection()
- with db.connection.cursor() as cursor:
+ with db.borrowCursor() as cursor:
cursor.execute(countSql, countValues)
totalItems = int(cursor.fetchone()["cnt"])
diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py
index 4fbea490..bdb3d23b 100644
--- a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py
+++ b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py
@@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300
def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str):
- """Reuse a pooled DB connector for the given feature database."""
+ """Reuse a pooled DB connector for the given feature database.
+
+ The underlying psycopg2 connections live in the central pool
+ (`_PoolRegistry`) and are recreated on demand if they go stale; we just
+ need to keep the lightweight connector wrapper around.
+ """
if featureDbName in _featureDbConnPool:
- conn = _featureDbConnPool[featureDbName]
- try:
- if conn.connection and not conn.connection.closed:
- return conn
- except Exception as e:
- logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}")
- _featureDbConnPool.pop(featureDbName, None)
+ return _featureDbConnPool[featureDbName]
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py
index 8bfa2628..959e42c9 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py
@@ -68,6 +68,9 @@ class ClickupBootstrapResult:
workspaces: int = 0
lists: int = 0
errors: List[str] = field(default_factory=list)
+ # First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None.
+ # Drives the same UI banner as the file-walker bootstraps.
+ stoppedAtLimit: Optional[str] = None
def _syntheticTaskId(connectionId: str, taskId: str) -> str:
@@ -225,6 +228,7 @@ async def bootstrapClickup(
cancelled = False
for ds in dataSources:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
+ _recordLimitStop(result, "maxTasks", "dataSource", limits)
break
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
cancelled = True
@@ -243,8 +247,11 @@ async def bootstrapClickup(
clickupScope=limits.clickupScope,
)
+ if len(teams) > dsLimits.maxWorkspaces:
+ _recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False)
for team in teams[:dsLimits.maxWorkspaces]:
if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks:
+ _recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits)
break
teamId = str(team.get("id", "") or "")
if not teamId:
@@ -351,6 +358,7 @@ async def _walkTeam(
for lst in listsCollected:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
+ _recordLimitStop(result, "maxTasks", f"team={teamId}", limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
return
@@ -407,6 +415,7 @@ async def _walkList(
for task in tasks:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
+ _recordLimitStop(result, "maxTasks", f"list={listId}", limits)
return
if not _isRecent(task.get("date_updated"), limits.maxAgeDays):
result.skippedPolicy += 1
@@ -529,13 +538,37 @@ async def _ingestTask(
)
+def _recordLimitStop(
+ result: ClickupBootstrapResult,
+ limitName: str,
+ where: str,
+ limits: ClickupBootstrapLimits,
+ *,
+ hard: bool = True,
+) -> None:
+ """See subConnectorSyncSharepoint._recordLimitStop for semantics."""
+ if hard or result.stoppedAtLimit is None:
+ result.stoppedAtLimit = limitName
+ budgetMap = {
+ "maxTasks": limits.maxTasks,
+ "maxWorkspaces": limits.maxWorkspaces,
+ "maxListsPerWorkspace": limits.maxListsPerWorkspace,
+ }
+ logger.warning(
+ "clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).",
+ limitName, budgetMap.get(limitName), where,
+ result.indexed, result.skippedDuplicate,
+ )
+
+
def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
- "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d",
+ "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.workspaces, result.lists, durationMs,
+ result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "clickup",
@@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"workspaces": result.workspaces,
"lists": result.lists,
"durationMs": durationMs,
+ "stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"lists": result.lists,
"durationMs": durationMs,
"errors": result.errors[:20],
+ "stoppedAtLimit": result.stoppedAtLimit,
+ "limits": {
+ "maxTasks": MAX_TASKS_DEFAULT,
+ "maxWorkspaces": MAX_WORKSPACES_DEFAULT,
+ "maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT,
+ "maxAgeDays": MAX_AGE_DAYS_DEFAULT,
+ },
}
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py
index 5dd1bd8b..e27abacb 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py
@@ -61,6 +61,8 @@ class GdriveBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
+ # See SharepointBootstrapResult.stoppedAtLimit — same semantics.
+ stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@@ -265,8 +267,10 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
+ _recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
+ _recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
@@ -276,6 +280,9 @@ async def _walkFolder(
mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType")
if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME:
+ if depth + 1 > limits.maxDepth:
+ _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
+ continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@@ -298,6 +305,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
+ _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
modifiedTime = metadata.get("modifiedTime")
@@ -470,13 +478,38 @@ async def _ingestOne(
await asyncio.sleep(0)
+def _recordLimitStop(
+ result: GdriveBootstrapResult,
+ limitName: str,
+ where: str,
+ limits: GdriveBootstrapLimits,
+ *,
+ hard: bool = True,
+) -> None:
+ """See subConnectorSyncSharepoint._recordLimitStop for semantics."""
+ if hard or result.stoppedAtLimit is None:
+ result.stoppedAtLimit = limitName
+ budgetMap = {
+ "maxItems": limits.maxItems,
+ "maxBytes": limits.maxBytes,
+ "maxDepth": limits.maxDepth,
+ "maxFileSize": limits.maxFileSize,
+ }
+ logger.warning(
+ "gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
+ limitName, budgetMap.get(limitName), where,
+ result.indexed, result.bytesProcessed,
+ )
+
+
def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
- "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d",
+ "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.bytesProcessed, durationMs,
+ result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "gdrive",
@@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"failed": result.failed,
"bytes": result.bytesProcessed,
"durationMs": durationMs,
+ "stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
+ "stoppedAtLimit": result.stoppedAtLimit,
+ "limits": {
+ "maxItems": MAX_ITEMS_DEFAULT,
+ "maxBytes": MAX_BYTES_DEFAULT,
+ "maxFileSize": MAX_FILE_SIZE_DEFAULT,
+ "maxDepth": MAX_DEPTH_DEFAULT,
+ },
}
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py
index e656abe8..dcf19e39 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py
@@ -53,6 +53,8 @@ class KdriveBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
+ # See SharepointBootstrapResult.stoppedAtLimit — same semantics.
+ stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@@ -232,14 +234,19 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
+ _recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
+ _recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False):
+ if depth + 1 > limits.maxDepth:
+ _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
+ continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@@ -262,6 +269,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
+ _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
@@ -415,17 +423,42 @@ async def _ingestOne(
await asyncio.sleep(0)
+def _recordLimitStop(
+ result: KdriveBootstrapResult,
+ limitName: str,
+ where: str,
+ limits: KdriveBootstrapLimits,
+ *,
+ hard: bool = True,
+) -> None:
+ """See subConnectorSyncSharepoint._recordLimitStop for semantics."""
+ if hard or result.stoppedAtLimit is None:
+ result.stoppedAtLimit = limitName
+ budgetMap = {
+ "maxItems": limits.maxItems,
+ "maxBytes": limits.maxBytes,
+ "maxDepth": limits.maxDepth,
+ "maxFileSize": limits.maxFileSize,
+ }
+ logger.warning(
+ "kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
+ limitName, budgetMap.get(limitName), where,
+ result.indexed, result.bytesProcessed,
+ )
+
+
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
- "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
+ "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
- durationMs,
+ durationMs, result.stoppedAtLimit or "none",
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
"connectionId": connectionId, "indexed": result.indexed,
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
- "failed": result.failed, "durationMs": durationMs},
+ "failed": result.failed, "durationMs": durationMs,
+ "stoppedAtLimit": result.stoppedAtLimit},
)
return {
"connectionId": result.connectionId,
@@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
+ "stoppedAtLimit": result.stoppedAtLimit,
+ "limits": {
+ "maxItems": MAX_ITEMS_DEFAULT,
+ "maxBytes": MAX_BYTES_DEFAULT,
+ "maxFileSize": MAX_FILE_SIZE_DEFAULT,
+ "maxDepth": MAX_DEPTH_DEFAULT,
+ },
}
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py
index 892e41ba..e06fd36b 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py
@@ -59,6 +59,10 @@ class SharepointBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
+ # First budget that hit zero; None means the walk completed naturally.
+ # Surfaces in the bootstrap result so the RAG inventory UI can warn the
+ # user that the corpus is incomplete and tell them which knob to turn.
+ stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@@ -259,14 +263,22 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
+ _recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
+ _recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False):
+ if depth + 1 > limits.maxDepth:
+ # We stop descending here but keep walking siblings.
+ # Record once per bootstrap so the UI shows "maxDepth" even
+ # if other budgets aren't exhausted yet.
+ _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
+ continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@@ -289,6 +301,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
+ _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
@@ -443,13 +456,44 @@ async def _ingestOne(
await asyncio.sleep(0)
+def _recordLimitStop(
+ result: SharepointBootstrapResult,
+ limitName: str,
+ where: str,
+ limits: SharepointBootstrapLimits,
+ *,
+ hard: bool = True,
+) -> None:
+ """Mark the FIRST limit that bit. Soft hits (per-file maxFileSize, per-folder
+ maxDepth) only record when no hard limit has yet stopped the run, so the UI
+ surfaces the most important reason.
+
+ Hard limits (maxItems / maxBytes) ALWAYS overwrite a previously recorded
+ soft limit — once a hard cap is hit, the corpus is provably incomplete.
+ """
+ if hard or result.stoppedAtLimit is None:
+ result.stoppedAtLimit = limitName
+ budgetMap = {
+ "maxItems": limits.maxItems,
+ "maxBytes": limits.maxBytes,
+ "maxDepth": limits.maxDepth,
+ "maxFileSize": limits.maxFileSize,
+ }
+ logger.warning(
+ "sharepoint walker hit %s=%s at %s — partial index "
+ "(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.",
+ limitName, budgetMap.get(limitName), where,
+ result.indexed, result.bytesProcessed,
+ )
+
+
def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
- "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
+ "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
- durationMs,
+ durationMs, result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "sharepoint",
@@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"skippedPolicy": result.skippedPolicy,
"failed": result.failed,
"durationMs": durationMs,
+ "stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
+ "stoppedAtLimit": result.stoppedAtLimit,
+ "limits": {
+ "maxItems": MAX_ITEMS_DEFAULT,
+ "maxBytes": MAX_BYTES_DEFAULT,
+ "maxFileSize": MAX_FILE_SIZE_DEFAULT,
+ "maxDepth": MAX_DEPTH_DEFAULT,
+ },
}
diff --git a/modules/shared/configuration.py b/modules/shared/configuration.py
index 721ce448..15646962 100644
--- a/modules/shared/configuration.py
+++ b/modules/shared/configuration.py
@@ -12,7 +12,8 @@ import logging
import json
import base64
import time
-from typing import Any, Dict, Optional
+import threading
+from typing import Any, Dict, Optional, Tuple
from pathlib import Path
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
@@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown
# Structure: {user_id: {key_name: [timestamps]}}
_decryption_attempts = {}
+# Process-wide plaintext cache for decrypted secrets.
+# Key: the encrypted ciphertext (which already includes env prefix).
+# Value: (expiresAtMonotonic, plaintext).
+# TTL is short enough that key rotation propagates quickly, long enough that
+# hot DB-init paths (every API call building a connector) don't blow the
+# decryption rate limit. 60s is a deliberate compromise.
+_DECRYPTION_CACHE_TTL_S = 60.0
+_decryption_cache: Dict[str, Tuple[float, str]] = {}
+_decryption_cache_lock = threading.Lock()
+
def _getMasterKey(envType: str = None) -> bytes:
"""
Get the master key for the specified environment.
@@ -486,25 +497,43 @@ def encryptValue(value: str, envType: str = None, userId: str = "system", keyNam
def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "unknown") -> str:
"""
Decrypt a value using the master key for the current environment.
-
+
+ A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted
+ first. The 10/sec rate-limit on cache misses still protects against
+ brute-force attacks; cache HITS bypass it because they are not actual
+ cryptographic operations — they just return the result of an earlier
+ successful decrypt. Without this cache, hot paths like
+ `mainBackgroundJobService._getDb()` (called per RAG inventory poll AND
+ per walker DB call) trigger the rate limit and surface as
+ "Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'"
+ ERRORs in the RAG inventory UI route.
+
Args:
encryptedValue: The encrypted value with prefix
userId: The user ID making the request (default: "system")
keyName: The name of the key being decrypted (default: "unknown")
-
+
Returns:
str: The decrypted plain text value
-
+
Raises:
ValueError: If decryption fails
"""
if not _isEncryptedValue(encryptedValue):
return encryptedValue # Return as-is if not encrypted
-
- # Check rate limiting (10 per second per user per key)
+
+ # Cache lookup BEFORE the rate-limit check: a cache hit is not a new
+ # cryptographic operation and must not be throttled.
+ now = time.monotonic()
+ with _decryption_cache_lock:
+ cached = _decryption_cache.get(encryptedValue)
+ if cached is not None and cached[0] > now:
+ return cached[1]
+
+ # Cache miss → real decrypt → apply rate limit.
if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10):
raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)")
-
+
try:
# Extract environment type from prefix
if encryptedValue.startswith('DEV_ENC:'):
@@ -536,7 +565,7 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
encryptedBytes = base64.urlsafe_b64decode(encryptedPart.encode('utf-8'))
decryptedBytes = fernet.decrypt(encryptedBytes)
decryptedValue = decryptedBytes.decode('utf-8')
-
+
# Log audit event for decryption
try:
from modules.shared.auditLogger import audit_logger
@@ -549,11 +578,25 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
except Exception:
# Don't fail if audit logging fails
pass
-
+
+ # Populate cache so subsequent reads of the same ciphertext don't
+ # re-decrypt (and don't consume rate-limit budget).
+ with _decryption_cache_lock:
+ _decryption_cache[encryptedValue] = (
+ time.monotonic() + _DECRYPTION_CACHE_TTL_S,
+ decryptedValue,
+ )
+
return decryptedValue
-
+
except Exception as e:
raise ValueError(f"Decryption failed: {e}")
+
+def clearDecryptionCache() -> None:
+ """Drop all cached plaintext secrets. Call after key rotation or in tests."""
+ with _decryption_cache_lock:
+ _decryption_cache.clear()
+
# Create the global APP_CONFIG instance
APP_CONFIG = Configuration()
\ No newline at end of file
diff --git a/modules/shared/dbMultiTenantOptimizations.py b/modules/shared/dbMultiTenantOptimizations.py
index c178c376..9b5a15b4 100644
--- a/modules/shared/dbMultiTenantOptimizations.py
+++ b/modules/shared/dbMultiTenantOptimizations.py
@@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None:
logger.debug(f"_ensureUamTablesMatchModels: {e}")
-def _getConnection(dbConnector):
- """Get a connection from the DatabaseConnector.
-
- Ensures the connection is alive and returns it.
- Commits any pending transaction first to avoid blocking.
+from contextlib import contextmanager
+
+
+@contextmanager
+def _borrowDbConn(dbConnector):
+ """Borrow a pooled connection from the DatabaseConnector.
+
+ Index/trigger/FK creation traditionally ran with `conn.autocommit = True`
+ so each CREATE statement is its own transaction (DDL on a managed
+ connection blocks waiting for COMMIT). This helper preserves that
+ behaviour on top of the pool: borrow a connection, flip it to autocommit,
+ yield it, and restore the previous state before returning it to the pool.
"""
- dbConnector._ensure_connection()
- conn = dbConnector.connection
- # Commit any pending transaction to avoid blocking
- try:
- conn.commit()
- except Exception:
- pass # Ignore if nothing to commit
- return conn
+ with dbConnector.borrowConn() as conn:
+ try:
+ previousAutocommit = conn.autocommit
+ except Exception:
+ previousAutocommit = False
+ try:
+ conn.autocommit = True
+ except Exception as e:
+ logger.debug(f"Could not set autocommit on borrowed connection: {e}")
+ try:
+ yield conn
+ finally:
+ try:
+ conn.autocommit = previousAutocommit
+ except Exception:
+ pass
# =============================================================================
@@ -174,73 +189,42 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
}
try:
- # Get a connection from the connector
- conn = _getConnection(dbConnector)
-
- # Save and set autocommit state
- try:
- originalAutocommit = conn.autocommit
- except Exception:
- originalAutocommit = False
-
- try:
- conn.autocommit = True
- except Exception as autoErr:
- logger.debug(f"Could not set autocommit: {autoErr}")
-
try:
_ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}")
-
- try:
+
+ with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor:
- # Apply indexes
results["indexesCreated"] = _applyIndexes(cursor, tables)
-
- # Apply foreign keys
results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables)
-
- # Apply immutable triggers
results["triggersCreated"] = _applyImmutableTriggers(cursor, tables)
-
- logger.info(
- f"Multi-tenant optimizations applied: "
- f"{results['indexesCreated']} indexes, "
- f"{results['triggersCreated']} triggers, "
- f"{results['foreignKeysCreated']} foreign keys"
- )
- finally:
- # Restore original autocommit state
- try:
- conn.autocommit = originalAutocommit
- except Exception:
- pass
-
+
+ logger.info(
+ f"Multi-tenant optimizations applied: "
+ f"{results['indexesCreated']} indexes, "
+ f"{results['triggersCreated']} triggers, "
+ f"{results['foreignKeysCreated']} foreign keys"
+ )
+
except Exception as e:
logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}")
results["errors"].append(str(e))
-
+
return results
def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int:
"""Apply only indexes (lighter operation, safe for frequent calls)."""
try:
- conn = _getConnection(dbConnector)
- originalAutocommit = conn.autocommit
- conn.autocommit = True
-
try:
_ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}")
-
- try:
+
+ with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor:
return _applyIndexes(cursor, tables)
- finally:
- conn.autocommit = originalAutocommit
except Exception as e:
logger.error(f"Error applying indexes: {e}")
return 0
@@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict:
}
try:
- conn = _getConnection(dbConnector)
- with conn.cursor() as cursor:
+ with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor:
# Check regular indexes
for tableName, indexName, _ in _INDEXES:
if _tableExists(cursor, tableName):
diff --git a/modules/shared/gdprDeletion.py b/modules/shared/gdprDeletion.py
index 99e09313..45a9ea43 100644
--- a/modules/shared/gdprDeletion.py
+++ b/modules/shared/gdprDeletion.py
@@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]:
ORDER BY ordinal_position
"""
- cursor = dbConnector.connection.cursor()
- cursor.execute(query, (tableName,))
- columns = [row[0] for row in cursor.fetchall()]
- cursor.close()
-
+ with dbConnector.borrowCursor() as cursor:
+ cursor.execute(query, (tableName,))
+ columns = [row[0] for row in cursor.fetchall()]
return columns
except Exception as e:
logger.error(f"Error getting columns for table {tableName}: {e}")
@@ -92,29 +90,26 @@ def _getAllTables(dbConnector) -> List[str]:
ORDER BY table_name
"""
- cursor = dbConnector.connection.cursor()
- cursor.execute(query)
- allTables = [row[0] for row in cursor.fetchall()]
-
- # Get foreign key relationships to determine dependency order
- fkQuery = """
- SELECT
- tc.table_name,
- ccu.table_name AS foreign_table_name
- FROM information_schema.table_constraints AS tc
- JOIN information_schema.key_column_usage AS kcu
- ON tc.constraint_name = kcu.constraint_name
- AND tc.table_schema = kcu.table_schema
- JOIN information_schema.constraint_column_usage AS ccu
- ON ccu.constraint_name = tc.constraint_name
- AND ccu.table_schema = tc.table_schema
- WHERE tc.constraint_type = 'FOREIGN KEY'
- AND tc.table_schema = 'public'
- """
-
- cursor.execute(fkQuery)
- foreignKeys = cursor.fetchall()
- cursor.close()
+ with dbConnector.borrowCursor() as cursor:
+ cursor.execute(query)
+ allTables = [row[0] for row in cursor.fetchall()]
+
+ fkQuery = """
+ SELECT
+ tc.table_name,
+ ccu.table_name AS foreign_table_name
+ FROM information_schema.table_constraints AS tc
+ JOIN information_schema.key_column_usage AS kcu
+ ON tc.constraint_name = kcu.constraint_name
+ AND tc.table_schema = kcu.table_schema
+ JOIN information_schema.constraint_column_usage AS ccu
+ ON ccu.constraint_name = tc.constraint_name
+ AND ccu.table_schema = tc.table_schema
+ WHERE tc.constraint_type = 'FOREIGN KEY'
+ AND tc.table_schema = 'public'
+ """
+ cursor.execute(fkQuery)
+ foreignKeys = cursor.fetchall()
# Build dependency graph (child -> parent mapping)
dependencies = {}
@@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]:
# Fallback: return simple list without ordering
try:
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
- cursor = dbConnector.connection.cursor()
- cursor.execute(query)
- tables = [row[0] for row in cursor.fetchall()]
- cursor.close()
+ with dbConnector.borrowCursor() as cursor:
+ cursor.execute(query)
+ tables = [row[0] for row in cursor.fetchall()]
return [t for t in tables if t not in PROTECTED_TABLES]
except Exception:
return []
@@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]:
AND i.indisprimary
"""
- cursor = dbConnector.connection.cursor()
- cursor.execute(query, (tableName,))
- pkColumns = [row[0] for row in cursor.fetchall()]
- cursor.close()
-
+ with dbConnector.borrowCursor() as cursor:
+ cursor.execute(query, (tableName,))
+ pkColumns = [row[0] for row in cursor.fetchall()]
return pkColumns
except Exception as e:
logger.debug(f"Could not get primary key for {tableName}: {e}")
@@ -229,21 +221,15 @@ def _findUserReferencesInTable(
return {}
references = {}
- cursor = dbConnector.connection.cursor()
-
- for userColumn in userColumns:
- # Build SELECT for primary key columns
- pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
- query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
-
- cursor.execute(query, (userId,))
- recordKeys = cursor.fetchall()
-
- if recordKeys:
- references[userColumn] = recordKeys
- logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
-
- cursor.close()
+ with dbConnector.borrowCursor() as cursor:
+ for userColumn in userColumns:
+ pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
+ query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
+ cursor.execute(query, (userId,))
+ recordKeys = cursor.fetchall()
+ if recordKeys:
+ references[userColumn] = recordKeys
+ logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
return references
except Exception as e:
@@ -277,42 +263,35 @@ def _anonymizeRecords(
return 0
try:
- cursor = dbConnector.connection.cursor()
+ # Resolve column metadata once outside the borrow block (it borrows its
+ # own connection internally).
+ columns = _getTableColumns(dbConnector, tableName)
+ hasModifiedAt = "sysModifiedAt" in columns
+
count = 0
-
- for recordKey in recordKeys:
- # Build WHERE clause for primary key
- whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
-
- # Check if table has sysModifiedAt column
- columns = _getTableColumns(dbConnector, tableName)
- hasModifiedAt = "sysModifiedAt" in columns
-
- if hasModifiedAt:
- query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
- params = [anonymousValue, getUtcTimestamp()]
- else:
- query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
- params = [anonymousValue]
-
- # Add primary key values to params
- if isinstance(recordKey, tuple):
- params.extend(recordKey)
- else:
- params.append(recordKey)
-
- cursor.execute(query, params)
- count += cursor.rowcount
-
- dbConnector.connection.commit()
- cursor.close()
-
+ with dbConnector.borrowCursor() as cursor:
+ for recordKey in recordKeys:
+ whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
+ if hasModifiedAt:
+ query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
+ params = [anonymousValue, getUtcTimestamp()]
+ else:
+ query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
+ params = [anonymousValue]
+
+ if isinstance(recordKey, tuple):
+ params.extend(recordKey)
+ else:
+ params.append(recordKey)
+
+ cursor.execute(query, params)
+ count += cursor.rowcount
+
logger.info(f"Anonymized {count} records in {tableName}.{columnName}")
return count
-
+
except Exception as e:
logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}")
- dbConnector.connection.rollback()
return 0
@@ -338,32 +317,23 @@ def _deleteRecords(
return 0
try:
- cursor = dbConnector.connection.cursor()
count = 0
-
- for recordKey in recordKeys:
- # Build WHERE clause for primary key
- whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
- query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
-
- # Prepare params
- if isinstance(recordKey, tuple):
- params = list(recordKey)
- else:
- params = [recordKey]
-
- cursor.execute(query, params)
- count += cursor.rowcount
-
- dbConnector.connection.commit()
- cursor.close()
-
+ with dbConnector.borrowCursor() as cursor:
+ for recordKey in recordKeys:
+ whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
+ query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
+ if isinstance(recordKey, tuple):
+ params = list(recordKey)
+ else:
+ params = [recordKey]
+ cursor.execute(query, params)
+ count += cursor.rowcount
+
logger.info(f"Deleted {count} records from {tableName}")
return count
-
+
except Exception as e:
logger.error(f"Error deleting records from {tableName}: {e}")
- dbConnector.connection.rollback()
return 0
diff --git a/scripts/stage0_filefolder_schema_check.py b/scripts/stage0_filefolder_schema_check.py
index 861d8671..d172e19c 100644
--- a/scripts/stage0_filefolder_schema_check.py
+++ b/scripts/stage0_filefolder_schema_check.py
@@ -25,7 +25,7 @@ if not c or not c.connection:
print("STAGE0: DB_CONNECTION=none (check config.ini / .env)")
raise SystemExit(2)
-cur = c.connection.cursor()
+cur = c.borrowCursor()
def _scalar(cur):
diff --git a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py
index 4f98ef4a..57094760 100644
--- a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py
+++ b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py
@@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like:
These tests pin the new contract: empty result sets still return ``[]`` /
``None`` (normal), but any exception inside the query path propagates as
-``DatabaseQueryError`` with the table name attached. The transaction is
-rolled back so the connection is usable for subsequent queries.
+``DatabaseQueryError`` with the table name attached.
+
+Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`)
+the connector borrows a connection from `_PoolRegistry` on every call via the
+`borrowConn()` context manager. The tests mock that context manager so the
+fast-fail contract is exercised without requiring a live Postgres server.
"""
from __future__ import annotations
+from contextlib import contextmanager
from unittest.mock import MagicMock
import pytest
@@ -25,7 +30,6 @@ import psycopg2.errors
from modules.connectors.connectorDbPostgre import (
DatabaseConnector,
DatabaseQueryError,
- _rollbackQuietly,
)
@@ -39,26 +43,44 @@ class DummyTable:
def _makeConnector(cursorBehavior):
- """Build a ``DatabaseConnector`` skeleton with mocked connection/cursor.
+ """Build a ``DatabaseConnector`` skeleton with a mocked pool borrow.
``cursorBehavior`` is a callable invoked with the cursor mock so the test
can configure ``execute``/``fetchall``/``fetchone`` per scenario.
+
+ Returns ``(connector, conn, cursor)``:
+ * ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can
+ assert that the borrow lifecycle did the right thing.
+ * ``cursor`` is the per-test cursor mock.
"""
connector = DatabaseConnector.__new__(DatabaseConnector)
+
cursor = MagicMock()
+ cursorBehavior(cursor)
+
cursorContext = MagicMock()
cursorContext.__enter__ = MagicMock(return_value=cursor)
cursorContext.__exit__ = MagicMock(return_value=False)
- connection = MagicMock()
- connection.cursor.return_value = cursorContext
- connector.connection = connection
+ conn = MagicMock()
+ conn.cursor.return_value = cursorContext
+
+ @contextmanager
+ def fakeBorrow():
+ try:
+ yield conn
+ except Exception:
+ conn.rollback()
+ raise
+ else:
+ conn.commit()
+
+ connector.borrowConn = fakeBorrow
connector._ensureTableExists = MagicMock(return_value=True)
connector._systemTableName = "_system"
- cursorBehavior(cursor)
- return connector, connection, cursor
+ return connector, conn, cursor
class TestGetRecordsetFailLoud:
@@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud:
def behavior(cursor):
cursor.execute.return_value = None
cursor.fetchall.return_value = []
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
result = connector.getRecordset(DummyTable)
assert result == []
- connection.rollback.assert_not_called()
+ conn.rollback.assert_not_called()
+ conn.commit.assert_called_once()
def test_dictAdaptErrorRaisesDatabaseQueryError(self):
"""Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise."""
@@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud:
cursor.execute.side_effect = psycopg2.ProgrammingError(
"can't adapt type 'dict'"
)
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(
@@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud:
assert excinfo.value.table == "DummyTable"
assert "can't adapt type 'dict'" in str(excinfo.value)
assert isinstance(excinfo.value.original, psycopg2.ProgrammingError)
- connection.rollback.assert_called_once()
+ conn.rollback.assert_called_once()
def test_missingColumnRaisesDatabaseQueryError(self):
def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedColumn(
'column "wat" does not exist'
)
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(DummyTable, recordFilter={"wat": "x"})
assert "wat" in str(excinfo.value)
- connection.rollback.assert_called_once()
+ conn.rollback.assert_called_once()
def test_operationalErrorRaisesDatabaseQueryError(self):
"""Connection lost mid-query is also a real failure that must propagate."""
def behavior(cursor):
cursor.execute.side_effect = psycopg2.OperationalError("connection lost")
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError):
connector.getRecordset(DummyTable)
- connection.rollback.assert_called_once()
+ conn.rollback.assert_called_once()
class TestGetRecordFailLoud:
@@ -122,37 +145,22 @@ class TestGetRecordFailLoud:
def behavior(cursor):
cursor.execute.return_value = None
cursor.fetchone.return_value = None
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
result = connector.getRecord(DummyTable, "missing-id")
assert result is None
- connection.rollback.assert_not_called()
+ conn.rollback.assert_not_called()
+ conn.commit.assert_called_once()
def test_queryErrorRaisesDatabaseQueryError(self):
def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedTable(
'relation "DummyTable" does not exist'
)
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecord(DummyTable, "any-id")
assert excinfo.value.table == "DummyTable"
- connection.rollback.assert_called_once()
-
-
-class TestRollbackQuietly:
- def test_rollsBackOnLiveConnection(self):
- connection = MagicMock()
- _rollbackQuietly(connection)
- connection.rollback.assert_called_once()
-
- def test_swallowsRollbackError(self):
- """Rollback failure must not mask the original query error."""
- connection = MagicMock()
- connection.rollback.side_effect = RuntimeError("rollback failed")
- _rollbackQuietly(connection)
-
- def test_noopOnNoneConnection(self):
- _rollbackQuietly(None)
+ conn.rollback.assert_called_once()
diff --git a/tests/unit/connectors/test_connectorDbPostgre_pool.py b/tests/unit/connectors/test_connectorDbPostgre_pool.py
new file mode 100644
index 00000000..9c389add
--- /dev/null
+++ b/tests/unit/connectors/test_connectorDbPostgre_pool.py
@@ -0,0 +1,304 @@
+# Copyright (c) 2026 Patrick Motsch
+# All rights reserved.
+"""Concurrency tests for the PostgreSQL connection pool.
+
+These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md`
+refactor delivered:
+
+* T1 — 50 threads × 100 calls in parallel produce 0 `OperationalError`s and
+ every call completes within reasonable time (p99 < 2 s).
+* T2 — Two threads `_loadRecord` + `_saveRecord` against the same connector
+ do not corrupt each other's cursors.
+* T3 — `statement_timeout` aborts a runaway `pg_sleep(60)` after ~30 s and
+ releases the connection back into the pool cleanly.
+
+The tests need a real PostgreSQL server because the bug they guard against
+only materialises with real psycopg2 sockets — a mocked connection never
+hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads
+`.env`) and are auto-skipped when the connection fails (no local Postgres,
+wrong creds, etc.) so `pytest` keeps working in CI-only environments.
+
+To run them locally:
+
+ pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v
+
+They use a throwaway database name (`poweron_pool_test_`) and drop it
+in fixture teardown so they leave nothing behind.
+"""
+from __future__ import annotations
+
+import time
+import uuid
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import psycopg2
+import psycopg2.errors
+import pytest
+from pydantic import Field
+
+from modules.connectors.connectorDbPostgre import (
+ DatabaseConnector,
+ _PoolRegistry,
+ closeAllPools,
+)
+from modules.datamodels.datamodelBase import PowerOnModel
+from modules.shared.configuration import APP_CONFIG
+
+
+def _dbConfig():
+ """Read DB connection params from APP_CONFIG (`.env`).
+
+ Returns ``None`` when host/user/password are not all present so the
+ test module can skip cleanly instead of blowing up at import time.
+ """
+ host = APP_CONFIG.get("DB_HOST")
+ user = APP_CONFIG.get("DB_USER")
+ password = APP_CONFIG.get("DB_PASSWORD_SECRET")
+ port = APP_CONFIG.get("DB_PORT", 5432)
+ if not host or not user or password is None:
+ return None
+ return {"host": host, "user": user, "password": password, "port": int(port)}
+
+
+def _canReachPostgres(cfg) -> bool:
+ """Try a quick connect to the admin DB so we can skip on connection failures."""
+ try:
+ conn = psycopg2.connect(
+ host=cfg["host"], port=cfg["port"], database="postgres",
+ user=cfg["user"], password=cfg["password"], connect_timeout=2,
+ )
+ conn.close()
+ return True
+ except Exception: # noqa: BLE001
+ return False
+
+
+_DB_CFG = _dbConfig()
+pytestmark = pytest.mark.skipif(
+ _DB_CFG is None or not _canReachPostgres(_DB_CFG),
+ reason="No reachable PostgreSQL — skipping live-Postgres pool tests",
+)
+
+
+class PoolTestRow(PowerOnModel):
+ """Tiny model used to exercise the pool — one ID + one payload field."""
+ payload: str = Field(default="", description="Test payload")
+
+
+@pytest.fixture
+def liveConnector():
+ """Spin up a throwaway database, yield a `DatabaseConnector` against it,
+ drop the database afterwards.
+
+ The pool registry is wiped before and after each test so state from one
+ test cannot mask a bug in another.
+ """
+ cfg = _DB_CFG
+ host = cfg["host"]
+ user = cfg["user"]
+ password = cfg["password"]
+ port = cfg["port"]
+ dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}"
+
+ # Pre-clean: drop any orphan test DB with the same name (shouldn't happen
+ # because we use a unique uuid, but be defensive).
+ adminConn = psycopg2.connect(
+ host=host, port=port, database="postgres", user=user, password=password
+ )
+ adminConn.autocommit = True
+ try:
+ with adminConn.cursor() as cur:
+ cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
+ finally:
+ adminConn.close()
+
+ closeAllPools()
+
+ connector = DatabaseConnector(
+ dbHost=host,
+ dbDatabase=dbName,
+ dbUser=user,
+ dbPassword=password,
+ dbPort=port,
+ )
+ # Seed exactly one row so every concurrent read has a stable target.
+ connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"})
+
+ yield connector
+
+ # Teardown: tear pools down, then drop the DB.
+ closeAllPools()
+ adminConn = psycopg2.connect(
+ host=host, port=port, database="postgres", user=user, password=password
+ )
+ adminConn.autocommit = True
+ try:
+ with adminConn.cursor() as cur:
+ cur.execute(
+ 'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s',
+ (dbName,),
+ )
+ cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
+ finally:
+ adminConn.close()
+
+
+class TestPoolConcurrency:
+ def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int):
+ """Run N worker threads, each issuing M reads. Return (errors, latencies)."""
+ errors: list = []
+ latencies: list = []
+ lock = threading.Lock()
+
+ def worker():
+ for _ in range(callsPerThread):
+ t0 = time.perf_counter()
+ try:
+ rows = liveConnector.getRecordset(PoolTestRow)
+ assert any(r["id"] == "seed" for r in rows)
+ except Exception as e: # noqa: BLE001 — we want every failure mode
+ with lock:
+ errors.append(e)
+ finally:
+ with lock:
+ latencies.append(time.perf_counter() - t0)
+
+ with ThreadPoolExecutor(max_workers=threadCount) as ex:
+ futures = [ex.submit(worker) for _ in range(threadCount)]
+ for f in as_completed(futures):
+ f.result()
+ latencies.sort()
+ return errors, latencies
+
+ def test_50_threads_x_20_reads_no_errors(self, liveConnector):
+ """T1a — STRESS: 50 threads × 20 reads each → 0 errors.
+
+ Pre-pool, this scenario produced either
+ `OperationalError: another command is already in progress` or a
+ deadlock in `recv()` because the threadpool shared one psycopg2
+ socket. With the pool plus `borrowConn`'s bounded wait, every
+ thread eventually gets a connection and completes — even with 30
+ threads queued waiting at any moment (pool max=20).
+ """
+ errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20)
+ assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
+
+ def test_20_threads_x_50_reads_latency_budget(self, liveConnector):
+ """T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s.
+
+ 20 threads matches the pool's `max=20` so there is no queueing —
+ every borrow returns immediately. This pins a sanity-level per-call
+ latency budget; pre-pool it was unbounded (recv() never returned).
+
+ The 5 s ceiling is generous on purpose: `getRecordset` calls
+ `_ensureTableExists` which runs two `information_schema` queries
+ for column-additive migration, and under 20-way concurrency on a
+ single Postgres instance that produces a long tail. The hard
+ assertion is `not errors` — the latency check just guarantees
+ nothing hangs indefinitely.
+ """
+ errors, latencies = self._runWorkers(
+ liveConnector, threadCount=20, callsPerThread=50
+ )
+ assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
+ p99 = latencies[int(len(latencies) * 0.99)]
+ assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget"
+
+ def test_interleaved_load_and_save_no_collision(self, liveConnector):
+ """T2: parallel reads + writes on the same connector → no cursor mix-up.
+
+ Pre-pool the reader could observe a row in mid-write or vice versa
+ because both shared the same cursor. With one connection per borrow,
+ the database's own row-locking is the only contention, and we just
+ need to assert no exceptions.
+ """
+ stopFlag = threading.Event()
+ errors: list = []
+ lock = threading.Lock()
+
+ def reader():
+ while not stopFlag.is_set():
+ try:
+ liveConnector.getRecord(PoolTestRow, "seed")
+ except Exception as e: # noqa: BLE001
+ with lock:
+ errors.append(("read", e))
+
+ def writer():
+ i = 0
+ while not stopFlag.is_set():
+ try:
+ liveConnector.recordModify(
+ PoolTestRow,
+ "seed",
+ {"id": "seed", "payload": f"v{i}"},
+ )
+ i += 1
+ except Exception as e: # noqa: BLE001
+ with lock:
+ errors.append(("write", e))
+
+ threads = [
+ threading.Thread(target=reader, daemon=True),
+ threading.Thread(target=reader, daemon=True),
+ threading.Thread(target=writer, daemon=True),
+ threading.Thread(target=writer, daemon=True),
+ ]
+ for t in threads:
+ t.start()
+ time.sleep(2.0)
+ stopFlag.set()
+ for t in threads:
+ t.join(timeout=3.0)
+
+ assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
+
+ def test_statement_timeout_releases_connection(self, liveConnector):
+ """T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact.
+
+ The bug we are guarding against: a runaway query with no timeout
+ hung `recv()` forever, the psycopg2 connection was poisoned, and the
+ whole backend became unresponsive once that connection was reused.
+ With `statement_timeout=30000` configured at pool construction the
+ query is cancelled by the server, the borrow context manager rolls
+ back, and the connection returns to the pool — proven by the fact
+ that a follow-up call still succeeds quickly.
+ """
+ # Use a short timeout to keep the test fast — override the pool's
+ # session statement_timeout for one borrow via SET LOCAL.
+ with liveConnector.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute("SET LOCAL statement_timeout = 500")
+ with pytest.raises(psycopg2.errors.QueryCanceled):
+ cursor.execute("SELECT pg_sleep(5)")
+
+ # Follow-up call must succeed quickly: connection is back in the pool.
+ t0 = time.perf_counter()
+ rows = liveConnector.getRecordset(PoolTestRow)
+ elapsed = time.perf_counter() - t0
+ assert any(r["id"] == "seed" for r in rows)
+ assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged"
+
+
+class TestPoolRegistry:
+ def test_one_pool_per_database_identity(self, liveConnector):
+ """Two connectors against the same (host, db, port) share one pool."""
+ cfg = _DB_CFG
+ pool1 = _PoolRegistry.getPool(
+ dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
+ dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
+ )
+ pool2 = _PoolRegistry.getPool(
+ dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
+ dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
+ )
+ assert pool1 is pool2
+
+ def test_close_all_clears_registry(self, liveConnector):
+ """`closeAllPools()` empties the registry so the next call rebuilds."""
+ # Touch the pool first.
+ liveConnector.getRecordset(PoolTestRow)
+ assert _PoolRegistry._pools, "pool should exist after a real call"
+ closeAllPools()
+ assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()"
diff --git a/tests/unit/interfaces/test_folderRbac.py b/tests/unit/interfaces/test_folderRbac.py
index 049f392d..f4b984aa 100644
--- a/tests/unit/interfaces/test_folderRbac.py
+++ b/tests/unit/interfaces/test_folderRbac.py
@@ -68,6 +68,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass):
return True
+ def borrowCursor(self):
+ """Mimic `DatabaseConnector.borrowCursor()` context manager."""
+ from contextlib import contextmanager
+ from unittest.mock import MagicMock
+
+ @contextmanager
+ def _cm():
+ yield MagicMock()
+ return _cm()
+
def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__
self._tables.setdefault(tableName, {})
diff --git a/tests/unit/routes/test_folder_crud.py b/tests/unit/routes/test_folder_crud.py
index 86eaf480..66bad903 100644
--- a/tests/unit/routes/test_folder_crud.py
+++ b/tests/unit/routes/test_folder_crud.py
@@ -69,6 +69,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass):
return True
+ def borrowCursor(self):
+ """Mimic `DatabaseConnector.borrowCursor()` context manager for the cascade test."""
+ from contextlib import contextmanager
+ from unittest.mock import MagicMock
+
+ @contextmanager
+ def _cm():
+ yield MagicMock()
+ return _cm()
+
def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__
self._tables.setdefault(tableName, {})
From 4064ac0266b67e77625a0726fec1f918f7ac1c88 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Mon, 18 May 2026 07:56:53 +0200
Subject: [PATCH 2/6] fixed toggle icons udb
---
app.py | 3 +
modules/datamodels/datamodelBackgroundJob.py | 11 +
modules/datamodels/datamodelDataSource.py | 41 ++-
.../datamodels/datamodelFeatureDataSource.py | 29 +-
.../trustee/accounting/accountingDataSync.py | 9 +-
modules/features/trustee/mainTrustee.py | 21 ++
.../features/trustee/routeFeatureTrustee.py | 24 +-
.../workspace/routeFeatureWorkspace.py | 1 +
modules/routes/routeDataSources.py | 267 ++++++++++++--
modules/routes/routeJobs.py | 18 +-
modules/routes/routeRagInventory.py | 73 +++-
.../mainBackgroundJobService.py | 40 +-
.../services/serviceChat/mainServiceChat.py | 10 +-
.../serviceKnowledge/_costEstimate.py | 86 +++++
.../serviceKnowledge/_inheritFlags.py | 342 ++++++++++++++++++
.../serviceKnowledge/_progressMessages.py | 23 ++
.../services/serviceKnowledge/_ragLimits.py | 107 ++++++
.../subConnectorIngestConsumer.py | 43 ++-
.../subConnectorSyncClickup.py | 27 +-
.../subConnectorSyncGdrive.py | 31 +-
.../serviceKnowledge/subConnectorSyncGmail.py | 6 +-
.../subConnectorSyncKdrive.py | 31 +-
.../subConnectorSyncOutlook.py | 6 +-
.../subConnectorSyncSharepoint.py | 36 +-
.../serviceKnowledge/subPolicyResolver.py | 70 +---
modules/shared/i18nRegistry.py | 42 +++
scripts/debug_rag_job_result.py | 70 ++++
..._db_migrate_backgroundjob_progress_data.py | 97 +++++
.../script_db_migrate_datasource_inherit.py | 110 ++++++
.../script_db_migrate_datasource_settings.py | 102 ++++++
tests/unit/services/test_costEstimate.py | 55 +++
tests/unit/services/test_inheritFlags.py | 330 +++++++++++++++++
.../test_knowledge_ingest_consumer.py | 39 +-
tests/unit/services/test_ragLimits.py | 79 ++++
34 files changed, 2107 insertions(+), 172 deletions(-)
create mode 100644 modules/serviceCenter/services/serviceKnowledge/_costEstimate.py
create mode 100644 modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py
create mode 100644 modules/serviceCenter/services/serviceKnowledge/_progressMessages.py
create mode 100644 modules/serviceCenter/services/serviceKnowledge/_ragLimits.py
create mode 100644 scripts/debug_rag_job_result.py
create mode 100644 scripts/script_db_migrate_backgroundjob_progress_data.py
create mode 100644 scripts/script_db_migrate_datasource_inherit.py
create mode 100644 scripts/script_db_migrate_datasource_settings.py
create mode 100644 tests/unit/services/test_costEstimate.py
create mode 100644 tests/unit/services/test_inheritFlags.py
create mode 100644 tests/unit/services/test_ragLimits.py
diff --git a/app.py b/app.py
index d94c7dd5..93cc8b79 100644
--- a/app.py
+++ b/app.py
@@ -418,6 +418,9 @@ async def lifespan(app: FastAPI):
registerKnowledgeIngestionConsumer,
)
registerKnowledgeIngestionConsumer()
+ # Side-effect import: registers all walker progress message keys
+ # in the i18n registry so `syncRegistryToDb` picks them up.
+ from modules.serviceCenter.services.serviceKnowledge import _progressMessages # noqa: F401
except Exception as e:
logger.warning(f"KnowledgeIngestionConsumer registration failed (non-critical): {e}")
diff --git a/modules/datamodels/datamodelBackgroundJob.py b/modules/datamodels/datamodelBackgroundJob.py
index fa99ea34..809fb994 100644
--- a/modules/datamodels/datamodelBackgroundJob.py
+++ b/modules/datamodels/datamodelBackgroundJob.py
@@ -96,6 +96,17 @@ class BackgroundJob(PowerOnModel):
description="Human-readable current step (e.g. 'Importing journal entries...')",
json_schema_extra={"label": "Fortschritts-Nachricht"},
)
+ progressMessageData: Optional[Dict[str, Any]] = Field(
+ None,
+ description=(
+ "Structured i18n payload for `progressMessage`. Shape: "
+ "{'key': '', 'params': {...}}. "
+ "Frontend renders via `t(key, params)`; older clients fall back "
+ "to `progressMessage`. Single source of truth — keep `progressMessage` "
+ "as the rendered fallback in the producing language."
+ ),
+ json_schema_extra={"label": "Fortschritts-Nachricht (i18n)"},
+ )
payload: Dict[str, Any] = Field(
default_factory=dict,
diff --git a/modules/datamodels/datamodelDataSource.py b/modules/datamodels/datamodelDataSource.py
index fe3f0442..de32bdf3 100644
--- a/modules/datamodels/datamodelDataSource.py
+++ b/modules/datamodels/datamodelDataSource.py
@@ -62,9 +62,14 @@ class DataSource(PowerOnModel):
description="Owner user ID",
json_schema_extra={"label": "Benutzer-ID", "fk_target": {"db": "poweron_app", "table": "UserInDB", "labelField": "username"}},
)
- ragIndexEnabled: bool = Field(
- default=False,
- description="When true this tree element is indexed into the RAG knowledge store",
+ ragIndexEnabled: Optional[bool] = Field(
+ default=None,
+ description=(
+ "Three-state RAG indexing flag with cascade-inherit semantics. "
+ "None = inherit from nearest ancestor DataSource (path-traversal); "
+ "True/False = explicit override that propagates to descendants. "
+ "Walker computes effective value via getEffectiveFlag()."
+ ),
json_schema_extra={"label": "Im RAG indexieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False},
)
lastIndexed: Optional[float] = Field(
@@ -72,9 +77,13 @@ class DataSource(PowerOnModel):
description="Timestamp of last successful RAG indexing run",
json_schema_extra={"label": "Letzte Indexierung", "frontend_type": "timestamp"},
)
- scope: str = Field(
- default="personal",
- description="Data visibility scope: personal, featureInstance, mandate, global",
+ scope: Optional[str] = Field(
+ default=None,
+ description=(
+ "Data visibility scope with inherit semantics. "
+ "None = inherit; values: personal, featureInstance, mandate, global. "
+ "Cascade-reset on parent toggle."
+ ),
json_schema_extra={"label": "Sichtbarkeit", "frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
{"value": "personal", "label": "Persönlich"},
{"value": "featureInstance", "label": "Feature-Instanz"},
@@ -82,11 +91,25 @@ class DataSource(PowerOnModel):
{"value": "global", "label": "Global"},
]},
)
- neutralize: bool = Field(
- default=False,
- description="Whether this data source should be neutralized before AI processing",
+ neutralize: Optional[bool] = Field(
+ default=None,
+ description=(
+ "Three-state neutralization flag with cascade-inherit semantics. "
+ "None = inherit from nearest ancestor DataSource (path-traversal); "
+ "True/False = explicit override that propagates to descendants."
+ ),
json_schema_extra={"label": "Neutralisieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False},
)
+ settings: Optional[Dict[str, Any]] = Field(
+ default=None,
+ description=(
+ "DataSource-scoped settings (JSON). Currently used keys: "
+ "ragLimits.{maxBytes,maxFileSize,maxItems,maxDepth}. "
+ "Walker reads these directly; missing keys fall back to RAG_LIMITS_DEFAULT "
+ "and are lazily persisted on next bootstrap."
+ ),
+ json_schema_extra={"label": "Einstellungen", "frontend_type": "json", "frontend_readonly": True, "frontend_required": False},
+ )
class ExternalEntry(BaseModel):
diff --git a/modules/datamodels/datamodelFeatureDataSource.py b/modules/datamodels/datamodelFeatureDataSource.py
index dd2c4035..f07a8bda 100644
--- a/modules/datamodels/datamodelFeatureDataSource.py
+++ b/modules/datamodels/datamodelFeatureDataSource.py
@@ -6,7 +6,7 @@ A FeatureDataSource links a FeatureInstance table (DATA_OBJECT) to a workspace
so the agent can query structured feature data (e.g. TrusteePosition rows).
"""
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from modules.datamodels.datamodelBase import PowerOnModel
from modules.shared.i18nRegistry import i18nModel
@@ -55,9 +55,12 @@ class FeatureDataSource(PowerOnModel):
description="Workspace feature instance where this source is used",
json_schema_extra={"label": "Workspace", "fk_target": {"db": "poweron_app", "table": "FeatureInstance", "labelField": "label"}},
)
- scope: str = Field(
- default="personal",
- description="Data visibility scope: personal, featureInstance, mandate, global",
+ scope: Optional[str] = Field(
+ default=None,
+ description=(
+ "Data visibility scope with inherit semantics. "
+ "None = inherit; values: personal, featureInstance, mandate, global."
+ ),
json_schema_extra={"label": "Sichtbarkeit", "frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
{"value": "personal", "label": "Persönlich"},
{"value": "featureInstance", "label": "Feature-Instanz"},
@@ -65,9 +68,12 @@ class FeatureDataSource(PowerOnModel):
{"value": "global", "label": "Global"},
]},
)
- neutralize: bool = Field(
- default=False,
- description="Whether this data source should be neutralized before AI processing",
+ neutralize: Optional[bool] = Field(
+ default=None,
+ description=(
+ "Three-state neutralization flag with cascade-inherit semantics. "
+ "None = inherit; True/False = explicit. Cascade-reset on parent toggle."
+ ),
json_schema_extra={"label": "Neutralisieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False},
)
neutralizeFields: Optional[List[str]] = Field(
@@ -80,3 +86,12 @@ class FeatureDataSource(PowerOnModel):
description="Record-level filter applied when querying this table, e.g. {'sessionId': 'abc-123'}",
json_schema_extra={"label": "Datensatzfilter"},
)
+ settings: Optional[Dict[str, Any]] = Field(
+ default=None,
+ description=(
+ "FeatureDataSource-scoped settings (JSON). Currently used keys: "
+ "ragLimits.{maxBytes,maxFileSize,maxItems,maxDepth}. "
+ "Mirror of DataSource.settings so the UDB settings modal can target both."
+ ),
+ json_schema_extra={"label": "Einstellungen", "frontend_type": "json", "frontend_readonly": True, "frontend_required": False},
+ )
diff --git a/modules/features/trustee/accounting/accountingDataSync.py b/modules/features/trustee/accounting/accountingDataSync.py
index 5827dd11..db50d657 100644
--- a/modules/features/trustee/accounting/accountingDataSync.py
+++ b/modules/features/trustee/accounting/accountingDataSync.py
@@ -205,11 +205,16 @@ class AccountingDataSync:
boundary so the UI poll on ``GET /api/jobs/{jobId}`` shows real
movement instead of jumping from 10 % to 100 %. Safe to omit.
"""
- def _progress(pct: int, msg: str) -> None:
+ def _progress(pct: int, msgKey: str, msgParams: Optional[Dict[str, Any]] = None) -> None:
+ """Forward to progressCb using the i18n contract.
+
+ `msgKey` is the German plaintext-as-key; the frontend translates
+ it via `t(key, params)` when rendering.
+ """
if progressCb is None:
return
try:
- progressCb(pct, msg)
+ progressCb(pct, messageKey=msgKey, messageParams=msgParams or {})
except Exception as ex:
logger.warning(f"progressCb failed at {pct}%: {ex}")
from modules.features.trustee.datamodelFeatureTrustee import (
diff --git a/modules/features/trustee/mainTrustee.py b/modules/features/trustee/mainTrustee.py
index 8f725d2f..b3f7cdcf 100644
--- a/modules/features/trustee/mainTrustee.py
+++ b/modules/features/trustee/mainTrustee.py
@@ -12,6 +12,27 @@ from modules.shared.i18nRegistry import t
logger = logging.getLogger(__name__)
+# i18n: register BackgroundJob progress message keys used by routeFeatureTrustee /
+# accountingDataSync. Walker call sites use `progressCb(..., messageKey="…")`
+# without going through `t()`, so we must register each key here as a
+# string-literal `t(...)` call -- per i18n convention `t()` MUST receive a
+# literal so static scanners and the boot-time `syncRegistryToDb` can pick
+# it up. Do NOT collapse these into a loop over a list of variables.
+t("Sync wird vorbereitet ({total} Position(en))...")
+t("Verbindungsaufbau fehlgeschlagen.")
+t("Keine aktive Buchhaltungs-Konfiguration gefunden.")
+t("Position {index}/{total} verarbeitet")
+t("Sync abgeschlossen.")
+t("Initialisiere Import...")
+t("Verbinde mit Buchhaltungssystem...")
+t("Import abgeschlossen.")
+t("Lade Kontenplan...")
+t("Lade Journaleintraege vom Buchhaltungssystem...")
+t("Lade Kunden...")
+t("Lade Lieferanten...")
+t("Lade Kontensaldi vom Buchhaltungssystem...")
+t("Speichere Kontensaldi...")
+
# Feature metadata
FEATURE_CODE = "trustee"
FEATURE_LABEL = t("Treuhand", context="UI")
diff --git a/modules/features/trustee/routeFeatureTrustee.py b/modules/features/trustee/routeFeatureTrustee.py
index 2c9c3328..a71b508f 100644
--- a/modules/features/trustee/routeFeatureTrustee.py
+++ b/modules/features/trustee/routeFeatureTrustee.py
@@ -1644,7 +1644,11 @@ async def _trusteeAccountingPushJobHandler(job: Dict[str, Any], progressCb) -> D
results = []
total = len(positionIds)
- progressCb(2, f"Sync wird vorbereitet ({total} Position(en))...")
+ progressCb(
+ 2,
+ messageKey="Sync wird vorbereitet ({total} Position(en))...",
+ messageParams={"total": total},
+ )
# Resolve connector + plain config once to avoid decryption rate-limits
# (mirrors the optimisation in pushBatchToAccounting). We push positions
@@ -1655,12 +1659,12 @@ async def _trusteeAccountingPushJobHandler(job: Dict[str, Any], progressCb) -> D
connector, plainConfig, configRecord = await bridge._resolveConnectorAndConfig(instanceId)
except Exception as resolveErr:
logger.exception("Accounting push: failed to resolve connector/config")
- progressCb(100, "Verbindungsaufbau fehlgeschlagen.")
+ progressCb(100, messageKey="Verbindungsaufbau fehlgeschlagen.")
raise resolveErr
if not connector or not plainConfig:
results = [SyncResult(success=False, errorMessage="No active accounting configuration found") for _ in positionIds]
- progressCb(100, "Keine aktive Buchhaltungs-Konfiguration gefunden.")
+ progressCb(100, messageKey="Keine aktive Buchhaltungs-Konfiguration gefunden.")
return {
"total": len(results),
"success": 0,
@@ -1680,7 +1684,11 @@ async def _trusteeAccountingPushJobHandler(job: Dict[str, Any], progressCb) -> D
results.append(result)
# Reserve 5..95% for the push loop, keep the tail for summary.
pct = 5 + int(90 * index / total)
- progressCb(pct, f"Position {index}/{total} verarbeitet")
+ progressCb(
+ pct,
+ messageKey="Position {index}/{total} verarbeitet",
+ messageParams={"index": index, "total": total},
+ )
skipped = [r for r in results if not r.success and r.errorMessage and "already synced" in r.errorMessage]
failed = [r for r in results if not r.success and r not in skipped]
@@ -1693,7 +1701,7 @@ async def _trusteeAccountingPushJobHandler(job: Dict[str, Any], progressCb) -> D
"; ".join(r.errorMessage or "unknown" for r in failed[:3]),
)
- progressCb(100, "Sync abgeschlossen.")
+ progressCb(100, messageKey="Sync abgeschlossen.")
return {
"total": len(results),
"success": sum(1 for r in results if r.success),
@@ -1823,10 +1831,10 @@ async def _trusteeAccountingSyncJobHandler(job: Dict[str, Any], progressCb) -> D
payload = job.get("payload") or {}
rootUser = getRootUser()
- progressCb(5, "Initialisiere Import...")
+ progressCb(5, messageKey="Initialisiere Import...")
interface = getInterface(rootUser, mandateId=mandateId, featureInstanceId=instanceId)
sync = AccountingDataSync(interface)
- progressCb(10, "Verbinde mit Buchhaltungssystem...")
+ progressCb(10, messageKey="Verbinde mit Buchhaltungssystem...")
result = await sync.importData(
featureInstanceId=instanceId,
mandateId=mandateId,
@@ -1834,7 +1842,7 @@ async def _trusteeAccountingSyncJobHandler(job: Dict[str, Any], progressCb) -> D
dateTo=payload.get("dateTo"),
progressCb=progressCb,
)
- progressCb(100, "Import abgeschlossen.")
+ progressCb(100, messageKey="Import abgeschlossen.")
return result
diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py
index 4487e5fe..2fa788e8 100644
--- a/modules/features/workspace/routeFeatureWorkspace.py
+++ b/modules/features/workspace/routeFeatureWorkspace.py
@@ -1324,6 +1324,7 @@ async def listWorkspaceConnections(
"externalUsername": conn.get("externalUsername"),
"externalEmail": conn.get("externalEmail"),
"status": status,
+ "knowledgeIngestionEnabled": bool(conn.get("knowledgeIngestionEnabled")),
})
return JSONResponse({"connections": items})
diff --git a/modules/routes/routeDataSources.py b/modules/routes/routeDataSources.py
index ba398008..5dec19c8 100644
--- a/modules/routes/routeDataSources.py
+++ b/modules/routes/routeDataSources.py
@@ -9,11 +9,40 @@ from fastapi import APIRouter, HTTPException, Depends, Path, Request, Body
from modules.auth import limiter, getRequestContext, RequestContext
from modules.datamodels.datamodelDataSource import DataSource
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
+from modules.datamodels.datamodelUam import UserConnection
from modules.shared.i18nRegistry import apiRouteContext
routeApiMsg = apiRouteContext("routeDataSources")
logger = logging.getLogger(__name__)
+
+def _ensureConnectionKnowledgeFlag(rootIf, connectionId: str) -> None:
+ """Forward-only sync: if a DataSource gets RAG-activated, ensure the parent
+ UserConnection.knowledgeIngestionEnabled is true.
+
+ Intentionally NOT bidirectional: disabling the last DataSource does NOT
+ auto-clear knowledgeIngestionEnabled, because the consent flag may have
+ been set explicitly via the Connections page / wizard even before any
+ DataSource exists. Only the master switch (`/knowledge-consent`) may
+ clear it.
+ """
+ if not connectionId:
+ return
+ try:
+ currentConn = rootIf.db.getRecord(UserConnection, connectionId)
+ if not currentConn:
+ return
+ if bool(currentConn.get("knowledgeIngestionEnabled")):
+ return
+ rootIf.db.recordModify(UserConnection, connectionId, {"knowledgeIngestionEnabled": True})
+ logger.info(
+ "Auto-enabled knowledgeIngestionEnabled on UserConnection %s "
+ "(triggered by first active DataSource).",
+ connectionId,
+ )
+ except Exception as e:
+ logger.warning("Could not auto-enable knowledgeIngestionEnabled for connection %s: %s", connectionId, e)
+
router = APIRouter(
prefix="/api/datasources",
tags=["Data Sources"],
@@ -45,26 +74,43 @@ def _findSourceRecord(db, sourceId: str):
def _updateDataSourceScope(
request: Request,
sourceId: str = Path(..., description="ID of the DataSource or FeatureDataSource"),
- scope: str = Body(..., embed=True),
+ scope: Optional[str] = Body(None, embed=True),
context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
- """Update the scope of a DataSource or FeatureDataSource. Global scope requires sysAdmin."""
- if scope not in _VALID_SCOPES:
- raise HTTPException(status_code=400, detail=f"Invalid scope: {scope}. Must be one of {_VALID_SCOPES}")
+ """Update the scope of a DataSource. Cascade-resets explicit descendants.
- if scope == "global" and not context.isSysAdmin:
- raise HTTPException(status_code=403, detail=routeApiMsg("Only sysadmins can set global scope"))
+ `scope=None` resets this node to inherit (no cascade). Global scope
+ requires sysAdmin.
+ """
+ if scope is not None:
+ if scope not in _VALID_SCOPES:
+ raise HTTPException(status_code=400, detail=f"Invalid scope: {scope}. Must be one of {_VALID_SCOPES}")
+ if scope == "global" and not context.isSysAdmin:
+ raise HTTPException(status_code=403, detail=routeApiMsg("Only sysadmins can set global scope"))
try:
from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
+ cascadeResetDescendants,
+ cascadeResetDescendantsFds,
+ )
rootIf = getRootInterface()
rec, model = _findSourceRecord(rootIf.db, sourceId)
if not rec:
raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
rootIf.db.recordModify(model, sourceId, {"scope": scope})
- logger.info("Updated scope=%s for %s %s", scope, model.__name__, sourceId)
- return {"sourceId": sourceId, "scope": scope, "updated": True}
+ cascaded = 0
+ if scope is not None:
+ if model is DataSource:
+ cascaded = cascadeResetDescendants(rootIf, rec, "scope")
+ else:
+ cascaded = cascadeResetDescendantsFds(rootIf, rec, "scope")
+ logger.info(
+ "Updated scope=%s for %s %s (cascade-reset %d descendants)",
+ scope, model.__name__, sourceId, cascaded,
+ )
+ return {"sourceId": sourceId, "scope": scope, "updated": True, "cascadedDescendants": cascaded}
except HTTPException:
raise
except Exception as e:
@@ -77,20 +123,36 @@ def _updateDataSourceScope(
def _updateDataSourceNeutralize(
request: Request,
sourceId: str = Path(..., description="ID of the DataSource or FeatureDataSource"),
- neutralize: bool = Body(..., embed=True),
+ neutralize: Optional[bool] = Body(None, embed=True),
context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
- """Toggle the neutralization flag on a DataSource or FeatureDataSource."""
+ """Set neutralize flag on a DataSource. Cascade-resets explicit descendants.
+
+ `neutralize=None` resets this node to inherit (no cascade).
+ """
try:
from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
+ cascadeResetDescendants,
+ cascadeResetDescendantsFds,
+ )
rootIf = getRootInterface()
rec, model = _findSourceRecord(rootIf.db, sourceId)
if not rec:
raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize})
- logger.info("Updated neutralize=%s for %s %s", neutralize, model.__name__, sourceId)
- return {"sourceId": sourceId, "neutralize": neutralize, "updated": True}
+ cascaded = 0
+ if neutralize is not None:
+ if model is DataSource:
+ cascaded = cascadeResetDescendants(rootIf, rec, "neutralize")
+ else:
+ cascaded = cascadeResetDescendantsFds(rootIf, rec, "neutralize")
+ logger.info(
+ "Updated neutralize=%s for %s %s (cascade-reset %d descendants)",
+ neutralize, model.__name__, sourceId, cascaded,
+ )
+ return {"sourceId": sourceId, "neutralize": neutralize, "updated": True, "cascadedDescendants": cascaded}
except HTTPException:
raise
except Exception as e:
@@ -132,13 +194,14 @@ def _updateNeutralizeFields(
async def _updateDataSourceRagIndex(
request: Request,
sourceId: str = Path(..., description="ID of the DataSource"),
- ragIndexEnabled: bool = Body(..., embed=True),
+ ragIndexEnabled: Optional[bool] = Body(None, embed=True),
context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
- """Toggle RAG indexing for a DataSource.
+ """Set RAG indexing flag on a DataSource. Cascade-resets explicit descendants.
- true: sets flag + enqueues mini-bootstrap for this DataSource only.
- false: sets flag + synchronously purges all chunks from this DataSource.
+ `ragIndexEnabled=None` resets this node to inherit (no cascade, no purge,
+ no bootstrap — the node simply follows its ancestor chain afterwards).
+ `True` enqueues a mini-bootstrap. `False` synchronously purges chunks.
Must be `async def` so `await startJob(...)` registers `_runJob` in the
main event loop. Sync route → worker thread → temporary loop closes
@@ -146,18 +209,26 @@ async def _updateDataSourceRagIndex(
"""
try:
from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import cascadeResetDescendants
rootIf = getRootInterface()
rec = rootIf.db.getRecord(DataSource, sourceId)
if not rec:
raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
rootIf.db.recordModify(DataSource, sourceId, {"ragIndexEnabled": ragIndexEnabled})
- logger.info("Updated ragIndexEnabled=%s for DataSource %s", ragIndexEnabled, sourceId)
+ cascaded = 0
+ if ragIndexEnabled is not None:
+ cascaded = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled")
+ logger.info(
+ "Updated ragIndexEnabled=%s for DataSource %s (cascade-reset %d descendants)",
+ ragIndexEnabled, sourceId, cascaded,
+ )
- if ragIndexEnabled:
+ connectionId = rec.get("connectionId") or rec.get("connection_id") or ""
+ if ragIndexEnabled is True:
+ _ensureConnectionKnowledgeFlag(rootIf, connectionId)
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
- connectionId = rec.get("connectionId") or rec.get("connection_id") or ""
conn = rootIf.getUserConnectionById(connectionId) if connectionId else None
authority = ""
if conn:
@@ -168,7 +239,7 @@ async def _updateDataSourceRagIndex(
{"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]},
triggeredBy=str(context.user.id),
)
- else:
+ elif ragIndexEnabled is False:
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId)
logger.info("Purged %d index rows / %d chunks for DataSource %s",
@@ -182,12 +253,164 @@ async def _updateDataSourceRagIndex(
mandateId=context.mandateId,
category=AuditCategory.PERMISSION.value,
action="rag_index_toggled",
- details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled}),
+ details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "cascadedDescendants": cascaded}),
)
- return {"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "updated": True}
+ return {"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "updated": True, "cascadedDescendants": cascaded}
except HTTPException:
raise
except Exception as e:
logger.error("Error updating datasource ragIndexEnabled: %s", e)
raise HTTPException(status_code=500, detail=str(e))
+
+
+_CLICKUP_SOURCE_TYPES = {"clickup", "clickupList", "clickupSpace", "clickupFolder"}
+_ALLOWED_RAG_LIMIT_KEYS = {
+ "files": {"maxItems", "maxBytes", "maxFileSize", "maxDepth"},
+ "clickup": {"maxTasks", "maxWorkspaces", "maxListsPerWorkspace"},
+}
+
+
+def _kindForSource(rec: Dict[str, Any], model) -> str:
+ """Map a DataSource record to a RAG-limits kind ('files' or 'clickup').
+
+ FeatureDataSource (tables, not file walkers) reports as 'files' so the
+ same UI/limit shape works; the limits simply won't be consumed by any
+ walker today but are stored for forward-compat.
+ """
+ if model is FeatureDataSource:
+ return "files"
+ sourceType = str(rec.get("sourceType") or "").strip()
+ return "clickup" if sourceType in _CLICKUP_SOURCE_TYPES else "files"
+
+
+def _sanitizeRagLimits(kind: str, raw: Any) -> Dict[str, int]:
+ """Coerce an incoming ragLimits dict to {allowedKey: positive int}.
+
+ Unknown keys are silently dropped; non-positive or non-numeric values
+ are rejected with 400.
+ """
+ if not isinstance(raw, dict):
+ raise HTTPException(status_code=400, detail="ragLimits must be an object")
+ allowed = _ALLOWED_RAG_LIMIT_KEYS.get(kind, set())
+ cleaned: Dict[str, int] = {}
+ for key, value in raw.items():
+ if key not in allowed:
+ continue
+ try:
+ intValue = int(value)
+ except (TypeError, ValueError):
+ raise HTTPException(status_code=400, detail=f"ragLimits.{key} must be an integer")
+ if intValue <= 0:
+ raise HTTPException(status_code=400, detail=f"ragLimits.{key} must be > 0")
+ cleaned[key] = intValue
+ return cleaned
+
+
+@router.patch("/{sourceId}/settings")
+@limiter.limit("30/minute")
+def _updateDataSourceSettings(
+ request: Request,
+ sourceId: str = Path(..., description="ID of the DataSource or FeatureDataSource"),
+ settings: Dict[str, Any] = Body(..., embed=True),
+ context: RequestContext = Depends(getRequestContext),
+) -> Dict[str, Any]:
+ """Replace `settings` on a DataSource or FeatureDataSource (partial merge per top-level key).
+
+ Currently supports `ragLimits` only. Unknown top-level keys in the body are
+ rejected to avoid silently storing garbage that no consumer reads.
+
+ Owner-only for personal DataSources; mandate/feature scopes additionally
+ accept the mandate or workspace admins of that scope.
+ """
+ if not isinstance(settings, dict):
+ raise HTTPException(status_code=400, detail="settings must be an object")
+ unknown = set(settings.keys()) - {"ragLimits"}
+ if unknown:
+ raise HTTPException(status_code=400, detail=f"Unknown settings keys: {sorted(unknown)}")
+
+ try:
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ rootIf = getRootInterface()
+ rec, model = _findSourceRecord(rootIf.db, sourceId)
+ if not rec:
+ raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
+
+ ownerId = str(rec.get("userId") or "")
+ currentUserId = str(context.user.id)
+ if ownerId and ownerId != currentUserId and not context.isSysAdmin:
+ scope = str(rec.get("scope") or "personal")
+ isMandateAdmin = getattr(context, "isMandateAdmin", False)
+ if scope == "personal" or not isMandateAdmin:
+ raise HTTPException(status_code=403, detail="Not allowed to modify this DataSource's settings")
+
+ kind = _kindForSource(rec, model)
+
+ currentSettings = rec.get("settings") or {}
+ if not isinstance(currentSettings, dict):
+ currentSettings = {}
+ newSettings = dict(currentSettings)
+
+ if "ragLimits" in settings:
+ cleanedLimits = _sanitizeRagLimits(kind, settings["ragLimits"])
+ mergedLimits = dict(currentSettings.get("ragLimits") or {})
+ mergedLimits.update(cleanedLimits)
+ newSettings["ragLimits"] = mergedLimits
+
+ rootIf.db.recordModify(model, sourceId, {"settings": newSettings})
+
+ import json
+ from modules.shared.auditLogger import audit_logger
+ from modules.datamodels.datamodelAudit import AuditCategory
+ audit_logger.logEvent(
+ userId=currentUserId,
+ mandateId=context.mandateId,
+ category=AuditCategory.PERMISSION.value,
+ action="datasource_settings_changed",
+ details=json.dumps({
+ "sourceId": sourceId,
+ "model": model.__name__,
+ "oldSettings": currentSettings,
+ "newSettings": newSettings,
+ }),
+ )
+ logger.info("Updated settings on %s %s by user %s", model.__name__, sourceId, currentUserId)
+ return {"sourceId": sourceId, "settings": newSettings, "updated": True}
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error("Error updating datasource settings: %s", e, exc_info=True)
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.get("/{sourceId}/cost-estimate")
+@limiter.limit("60/minute")
+def _getDataSourceCostEstimate(
+ request: Request,
+ sourceId: str = Path(..., description="ID of the DataSource or FeatureDataSource"),
+ context: RequestContext = Depends(getRequestContext),
+) -> Dict[str, Any]:
+ """Return an indicative full-sync cost estimate for the given DataSource.
+
+ Uses the current effective ragLimits (DataSource.settings.ragLimits with
+ fallback to centralized defaults) as the basis. Returns the same
+ `{estimatedTokens, estimatedUsd, basis}` shape regardless of source kind.
+ """
+ try:
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.serviceCenter.services.serviceKnowledge import _ragLimits, _costEstimate
+ rootIf = getRootInterface()
+ rec, model = _findSourceRecord(rootIf.db, sourceId)
+ if not rec:
+ raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
+
+ kind = _kindForSource(rec, model)
+ effective = _ragLimits.getRagLimits(rec, kind)
+ estimate = _costEstimate.estimateBootstrapCost(effective, kind=kind)
+ estimate["sourceId"] = sourceId
+ return estimate
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error("Error computing cost estimate: %s", e, exc_info=True)
+ raise HTTPException(status_code=500, detail=str(e))
diff --git a/modules/routes/routeJobs.py b/modules/routes/routeJobs.py
index d2124a0b..9cd89d46 100644
--- a/modules/routes/routeJobs.py
+++ b/modules/routes/routeJobs.py
@@ -21,7 +21,7 @@ from modules.serviceCenter.services.serviceBackgroundJobs import (
getJobStatus,
listJobs,
)
-from modules.shared.i18nRegistry import apiRouteContext
+from modules.shared.i18nRegistry import apiRouteContext, resolveJobMessage
logger = logging.getLogger(__name__)
routeApiMsg = apiRouteContext("routeJobs")
@@ -34,8 +34,20 @@ router = APIRouter(
def _serialiseJob(job: Dict[str, Any]) -> Dict[str, Any]:
- """Strip system audit fields and ensure JSON-safe types."""
- return {k: v for k, v in job.items() if not k.startswith("sys")}
+ """Strip system audit fields, ensure JSON-safe types, translate progress.
+
+ Walkers store progress as a structured payload (``progressMessageData =
+ {key, params}``). The frontend never calls ``t()`` on backend-supplied
+ keys (i18n convention #2), so we resolve the payload here using the
+ request-context language and overwrite ``progressMessage`` with the
+ fully rendered string. Older clients keep working because they read
+ the same field.
+ """
+ out = {k: v for k, v in job.items() if not k.startswith("sys")}
+ translated = resolveJobMessage(out.get("progressMessageData"))
+ if translated:
+ out["progressMessage"] = translated
+ return out
def _userHasMandateAccess(context: RequestContext, mandateId: Optional[str]) -> bool:
diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py
index 7c426d77..99d5c4df 100644
--- a/modules/routes/routeRagInventory.py
+++ b/modules/routes/routeRagInventory.py
@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Depends, Request
from modules.auth import limiter, getCurrentUser, getRequestContext, RequestContext
from modules.datamodels.datamodelUam import User
-from modules.shared.i18nRegistry import apiRouteContext
+from modules.shared.i18nRegistry import apiRouteContext, resolveJobMessage
routeApiMsg = apiRouteContext("routeRagInventory")
logger = logging.getLogger(__name__)
@@ -24,6 +24,53 @@ router = APIRouter(
)
+_SUB_RESULT_KEYS = ("sharepoint", "outlook", "drive", "gmail", "clickup", "kdrive")
+
+
+def _flattenJobResult(result: Dict[str, Any]) -> Dict[str, Any]:
+ """Bootstrap handlers nest per-service results (e.g. msft returns
+ `{"sharepoint": {...}, "outlook": {...}}`). The UI needs per-connection
+ aggregates AND the first hit limit, so we sum the counters and pick the
+ most informative `stoppedAtLimit` across sub-services.
+
+ Returns a flat dict with the same keys the UI expects on `lastSuccess`.
+ """
+ subResults = [result[k] for k in _SUB_RESULT_KEYS if isinstance(result.get(k), dict)]
+ if not subResults:
+ # Single-service handler that returns flat dict directly (legacy path).
+ return result
+
+ indexed = sum(int(r.get("indexed") or 0) for r in subResults)
+ skippedDup = sum(int(r.get("skippedDuplicate") or 0) for r in subResults)
+ skippedPol = sum(int(r.get("skippedPolicy") or 0) for r in subResults)
+ failed = sum(int(r.get("failed") or 0) for r in subResults)
+ bytes_ = sum(int(r.get("bytesProcessed") or 0) for r in subResults)
+ # Parallel sub-services: wall-clock ≈ slowest one.
+ durationMs = max((int(r.get("durationMs") or 0) for r in subResults), default=0)
+
+ # First sub-service that hit a limit wins — UI shows one banner per
+ # connection; if multiple stopped, the first one is informative enough
+ # and the user re-runs after raising that budget.
+ stoppedAtLimit: Optional[str] = None
+ limits: Dict[str, Any] = {}
+ for r in subResults:
+ if r.get("stoppedAtLimit"):
+ stoppedAtLimit = r["stoppedAtLimit"]
+ limits = r.get("limits") or {}
+ break
+
+ return {
+ "indexed": indexed,
+ "skippedDuplicate": skippedDup,
+ "skippedPolicy": skippedPol,
+ "failed": failed,
+ "bytesProcessed": bytes_,
+ "durationMs": durationMs,
+ "stoppedAtLimit": stoppedAtLimit,
+ "limits": limits,
+ }
+
+
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
"""Build per-connection RAG inventory rows.
@@ -111,7 +158,17 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
jobs = jobService.listJobs(jobType="connection.bootstrap", limit=50)
connJobs = [j for j in jobs if (j.get("payload") or {}).get("connectionId") == connectionId]
runningJobs = [
- {"jobId": j["id"], "progress": j.get("progress", 0), "progressMessage": j.get("progressMessage", "")}
+ {
+ "jobId": j["id"],
+ "progress": j.get("progress", 0),
+ # Server-side translate the structured walker payload into
+ # the request-context language; frontend renders 1:1 (no
+ # `t()` on backend-supplied keys).
+ "progressMessage": (
+ resolveJobMessage(j.get("progressMessageData"))
+ or j.get("progressMessage", "")
+ ),
+ }
for j in connJobs
if j.get("status") in ("PENDING", "RUNNING")
]
@@ -126,7 +183,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"finishedAt": j.get("finishedAt"),
}
elif status == "SUCCESS" and lastSuccess is None:
- result = j.get("result") or {}
+ # Bootstrap handlers may return either a flat dict (single
+ # service) or a nested dict keyed by sub-service (e.g. msft
+ # returns {"sharepoint": {...}, "outlook": {...}}). Flatten
+ # so the UI always sees aggregated counters and the first
+ # sub-service that hit a limit.
+ result = _flattenJobResult(j.get("result") or {})
lastSuccess = {
"jobId": j["id"],
"finishedAt": j.get("finishedAt"),
@@ -337,7 +399,10 @@ def _getActiveJobs(
"connectionLabel": getattr(conn, "displayLabel", None) or getattr(conn, "authority", connId),
"jobType": j.get("jobType", "connection.bootstrap"),
"progress": j.get("progress", 0),
- "progressMessage": j.get("progressMessage", ""),
+ "progressMessage": (
+ resolveJobMessage(j.get("progressMessageData"))
+ or j.get("progressMessage", "")
+ ),
})
return active
except Exception as e:
diff --git a/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py b/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py
index e27dae58..90b69bce 100644
--- a/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py
+++ b/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py
@@ -54,19 +54,53 @@ _CANCEL_CHECK_INTERVAL_S = 3.0
class JobProgressCallback:
- """Callable progress reporter with cooperative cancel-check for long-running walkers."""
+ """Callable progress reporter with cooperative cancel-check for long-running walkers.
+
+ Two ways to set a progress message:
+ progressCb(50, "145 Dateien verarbeitet") # legacy plaintext (DE)
+ progressCb(50, messageKey="{n} Dateien verarbeitet",
+ messageParams={"n": 145}) # i18n-friendly
+
+ When `messageKey` is given the structured payload is written to
+ `BackgroundJob.progressMessageData` so the frontend can render it via
+ `t(key, params)` in the user's UI language. A best-effort rendered
+ fallback is also stored in `progressMessage` for older clients, logs,
+ and audit trails.
+ """
def __init__(self, jobId: str):
self._jobId = jobId
self._cancelledCache: Optional[bool] = None
self._lastCheckedAt: float = 0.0
- def __call__(self, progress: int, message: Optional[str] = None) -> None:
+ def __call__(
+ self,
+ progress: int,
+ message: Optional[str] = None,
+ *,
+ messageKey: Optional[str] = None,
+ messageParams: Optional[Dict[str, Any]] = None,
+ ) -> None:
try:
clamped = max(0, min(100, int(progress)))
fields: Dict[str, Any] = {"progress": clamped}
- if message is not None:
+
+ if messageKey is not None:
+ params = messageParams or {}
+ try:
+ fallback = messageKey.format(**params)
+ except (KeyError, IndexError, ValueError) as fmtErr:
+ fallback = message or messageKey
+ logger.warning(
+ "progressCb message format failed for job %s key=%r params=%r: %s",
+ self._jobId, messageKey, params, fmtErr,
+ )
+ fields["progressMessageData"] = {"key": messageKey, "params": params}
+ fields["progressMessage"] = (message or fallback)[:500]
+ elif message is not None:
fields["progressMessage"] = message[:500]
+ fields["progressMessageData"] = None
+
_updateJob(self._jobId, fields)
except Exception as ex:
logger.warning("Progress update failed for job %s: %s", self._jobId, ex)
diff --git a/modules/serviceCenter/services/serviceChat/mainServiceChat.py b/modules/serviceCenter/services/serviceChat/mainServiceChat.py
index 2ca61d7e..61026de0 100644
--- a/modules/serviceCenter/services/serviceChat/mainServiceChat.py
+++ b/modules/serviceCenter/services/serviceChat/mainServiceChat.py
@@ -534,11 +534,17 @@ class ChatService:
) -> Dict[str, Any]:
"""Create a new external data source reference.
- Returns existing record if connectionId + path already exists (upsert semantics).
+ Upsert key is `(connectionId, sourceType, path)`. The same `path='/'`
+ can carry multiple DataSources discriminated by sourceType: the
+ Connection-Root (sourceType=, e.g. 'msft') plus one per
+ service (sourceType='sharepointFolder', 'outlookFolder', ...). The
+ sourceType filter MUST be present, otherwise a Service-Root POST
+ returns the Connection-Root and toggles cascade onto every sibling.
"""
from modules.datamodels.datamodelDataSource import DataSource
existing = self.interfaceDbApp.db.getRecordset(
- DataSource, recordFilter={"connectionId": connectionId, "path": path}
+ DataSource,
+ recordFilter={"connectionId": connectionId, "sourceType": sourceType, "path": path},
)
if existing:
return existing[0] if isinstance(existing[0], dict) else existing[0].model_dump()
diff --git a/modules/serviceCenter/services/serviceKnowledge/_costEstimate.py b/modules/serviceCenter/services/serviceKnowledge/_costEstimate.py
new file mode 100644
index 00000000..565c219d
--- /dev/null
+++ b/modules/serviceCenter/services/serviceKnowledge/_costEstimate.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""Indicative cost estimation for a RAG bootstrap run.
+
+This is **not** a billing-grade forecast: it gives the user a back-of-the-envelope
+USD figure for the worst-case full sync, so they can sanity-check before raising
+`maxBytes`/`maxItems`. The output always carries the underlying assumptions
+(`basis`) so the user can judge plausibility.
+
+Heuristic:
+ estimatedTokens = ceil(maxBytes / CHARS_PER_TOKEN_BYTES_FACTOR)
+ estimatedUsd = estimatedTokens / 1_000_000 * EMBEDDING_USD_PER_MTOKEN
+
+Defaults match OpenAI `text-embedding-3-small` pricing (2026-Q2).
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Any, Dict
+
+
+CHARS_PER_TOKEN = 4
+EMBEDDING_USD_PER_MTOKEN = 0.02
+DEFAULT_TOKENS_PER_ITEM = 1500
+BYTES_PER_TOKEN_TEXT_FACTOR = 4
+EXTRACTABLE_FRACTION = 0.4
+
+
+def estimateBootstrapCost(limits: Dict[str, int], kind: str = "files") -> Dict[str, Any]:
+ """Return an indicative cost estimate dict for a DataSource bootstrap.
+
+ Returned shape::
+
+ {
+ "estimatedTokens": int,
+ "estimatedUsd": float, # rounded to 4 decimals
+ "basis": {
+ "kind": "files"|"clickup",
+ "limits": {...},
+ "assumptions": {
+ "embeddingUsdPerMToken": 0.02,
+ "charsPerToken": 4,
+ "extractableFraction": 0.4,
+ "tokensPerItem": 1500 # only for clickup-like item counts
+ },
+ "notes": "non-binding, depends on real file content..."
+ }
+ }
+ """
+ assumptions: Dict[str, Any] = {
+ "embeddingUsdPerMToken": EMBEDDING_USD_PER_MTOKEN,
+ "charsPerToken": CHARS_PER_TOKEN,
+ }
+
+ if kind == "files":
+ maxBytes = int(limits.get("maxBytes") or 0)
+ extractableBytes = maxBytes * EXTRACTABLE_FRACTION
+ estimatedTokens = int(math.ceil(extractableBytes / BYTES_PER_TOKEN_TEXT_FACTOR))
+ assumptions["extractableFraction"] = EXTRACTABLE_FRACTION
+ assumptions["formula"] = "ceil(maxBytes * 0.4 / 4)"
+ elif kind == "clickup":
+ maxTasks = int(limits.get("maxTasks") or 0)
+ maxWorkspaces = max(1, int(limits.get("maxWorkspaces") or 1))
+ estimatedTokens = maxTasks * maxWorkspaces * DEFAULT_TOKENS_PER_ITEM
+ assumptions["tokensPerItem"] = DEFAULT_TOKENS_PER_ITEM
+ assumptions["formula"] = "maxTasks * maxWorkspaces * 1500"
+ else:
+ estimatedTokens = 0
+ assumptions["formula"] = "unknown kind, returning zero"
+
+ estimatedUsd = round(estimatedTokens / 1_000_000 * EMBEDDING_USD_PER_MTOKEN, 4)
+
+ return {
+ "estimatedTokens": estimatedTokens,
+ "estimatedUsd": estimatedUsd,
+ "basis": {
+ "kind": kind,
+ "limits": dict(limits),
+ "assumptions": assumptions,
+ "notes": (
+ "Indicative only. Actual cost depends on file types, extractable text "
+ "ratio, dedup hit-rate, retries, and current embedding model pricing."
+ ),
+ },
+ }
diff --git a/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py
new file mode 100644
index 00000000..00180c9f
--- /dev/null
+++ b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py
@@ -0,0 +1,342 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""Cascade-inherit semantics for DataSource flags (neutralize, ragIndexEnabled, scope).
+
+Three-state flags allow tree elements to either set an explicit value or
+inherit the value from their nearest ancestor in the path hierarchy. The
+walker (RAG/Neutralize) and routes resolve the *effective* value; the cascade
+helper resets explicit descendant values when a parent is toggled.
+
+Path-traversal rules:
+- A DataSource is identified by `(connectionId, sourceType, path)`.
+- The root of a service tree is `path == '/'`.
+- Sub-elements have paths like `/folder1/sub`. Their parent path is the
+ longest prefix path that exists as a DataSource record (string-based).
+- If no ancestor with an explicit value exists, the default is `False`
+ (or `'personal'` for scope) — matching the legacy behavior of NULL = inherit.
+"""
+
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+logger = logging.getLogger(__name__)
+
+_INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope")
+
+# Connection-root DataSources carry the authority as their sourceType
+# (e.g. 'msft', 'google'). They sit one level above all service DataSources
+# of the same connection in the visual tree, so flag inheritance must
+# cross sourceType boundaries — but ONLY from these authority roots.
+_AUTHORITY_SOURCE_TYPES = frozenset({"local", "google", "msft", "clickup", "infomaniak"})
+
+
+def _normalisePath(path: Optional[str]) -> str:
+ """Normalize a DataSource path to '/'-prefixed, no trailing slash (except root)."""
+ if not path:
+ return "/"
+ p = str(path).strip()
+ if not p.startswith("/"):
+ p = "/" + p
+ if len(p) > 1 and p.endswith("/"):
+ p = p.rstrip("/")
+ return p
+
+
+def _flagDefault(flag: str) -> Any:
+ if flag == "scope":
+ return "personal"
+ return False
+
+
+def _isExplicit(value: Any) -> bool:
+ """A flag value is explicit when it is not None.
+
+ Note: legacy rows may carry empty-string scope; treat as inherit too.
+ """
+ if value is None:
+ return False
+ if isinstance(value, str) and value == "":
+ return False
+ return True
+
+
+def _getRecordValue(rec: Any, key: str) -> Any:
+ if isinstance(rec, dict):
+ return rec.get(key)
+ return getattr(rec, key, None)
+
+
+def _findAncestorChain(
+ rec: Dict[str, Any],
+ allDs: Iterable[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """Return all ancestor DataSources of `rec` in the same connection,
+ ordered nearest-first.
+
+ Two ancestor relations are merged:
+ 1) **same-sourceType path-ancestor** — strict path-prefix within the
+ same service tree (sharepointFolder, gmailFolder, ...).
+ 2) **connection-root ancestor** — a DS with `path='/'` and
+ `sourceType` ∈ authority set (msft, google, ...) is the parent of
+ every other DS in that connection regardless of sourceType, so a
+ toggle on the connection node propagates to all services beneath.
+
+ The connection-root is always the most distant ancestor and therefore
+ sorts after any same-sourceType ancestors.
+ """
+ recPath = _normalisePath(_getRecordValue(rec, "path"))
+ recSourceType = _getRecordValue(rec, "sourceType")
+ recConnectionId = _getRecordValue(rec, "connectionId")
+ sameTypeCandidates: List[Tuple[int, Dict[str, Any]]] = []
+ connectionRoot: Optional[Dict[str, Any]] = None
+ recIsConnectionRoot = recSourceType in _AUTHORITY_SOURCE_TYPES and recPath == "/"
+ for cand in allDs:
+ if _getRecordValue(cand, "id") == _getRecordValue(rec, "id"):
+ continue
+ if _getRecordValue(cand, "connectionId") != recConnectionId:
+ continue
+ candSourceType = _getRecordValue(cand, "sourceType")
+ candPath = _normalisePath(_getRecordValue(cand, "path"))
+ if candSourceType == recSourceType:
+ if candPath == recPath or not _isAncestorPath(candPath, recPath):
+ continue
+ sameTypeCandidates.append((len(candPath), cand))
+ elif (
+ not recIsConnectionRoot
+ and candSourceType in _AUTHORITY_SOURCE_TYPES
+ and candPath == "/"
+ ):
+ connectionRoot = cand
+ sameTypeCandidates.sort(key=lambda x: x[0], reverse=True)
+ chain = [c for _, c in sameTypeCandidates]
+ if connectionRoot is not None:
+ chain.append(connectionRoot)
+ return chain
+
+
+def _isAncestorPath(ancestor: str, descendant: str) -> bool:
+ """True iff `ancestor` is a strict path-prefix of `descendant`.
+
+ '/' is ancestor of every non-root path. For non-root prefixes, the
+ descendant must continue with '/' so '/foo' isn't treated as ancestor of
+ '/foobar'.
+ """
+ if ancestor == descendant:
+ return False
+ if ancestor == "/":
+ return descendant != "/"
+ return descendant.startswith(ancestor + "/")
+
+
+def getEffectiveFlag(
+ rec: Dict[str, Any],
+ flag: str,
+ sameConnectionDs: Iterable[Dict[str, Any]],
+) -> Any:
+ """Resolve the effective value of a flag via path-traversal.
+
+ Order: own value (if explicit) → nearest ancestor with explicit value →
+ static default (`False` or `'personal'`).
+ """
+ if flag not in _INHERITABLE_FLAGS:
+ raise ValueError(f"Unknown inheritable flag: {flag}")
+ own = _getRecordValue(rec, flag)
+ if _isExplicit(own):
+ return own
+ chain = _findAncestorChain(rec, sameConnectionDs)
+ for ancestor in chain:
+ ancestorVal = _getRecordValue(ancestor, flag)
+ if _isExplicit(ancestorVal):
+ return ancestorVal
+ return _flagDefault(flag)
+
+
+def cascadeResetDescendants(
+ rootIf: Any,
+ parentRec: Dict[str, Any],
+ flag: str,
+) -> int:
+ """Reset all explicit descendant values of `flag` to NULL (= inherit).
+
+ Descendant relation mirrors `_findAncestorChain`:
+ - Connection-root (`path='/'` AND `sourceType` ∈ authorities) is parent
+ of every other DS in that connection (cross-sourceType cascade).
+ - Otherwise: same-sourceType strict path-descendants only.
+
+ Only the targeted `flag` is reset; other flags on the descendant are
+ untouched.
+
+ Returns the number of records updated.
+ """
+ if flag not in _INHERITABLE_FLAGS:
+ raise ValueError(f"Unknown inheritable flag: {flag}")
+ from modules.datamodels.datamodelDataSource import DataSource
+
+ connectionId = _getRecordValue(parentRec, "connectionId")
+ parentSourceType = _getRecordValue(parentRec, "sourceType")
+ parentPath = _normalisePath(_getRecordValue(parentRec, "path"))
+ parentId = _getRecordValue(parentRec, "id")
+ if not connectionId or not parentSourceType:
+ return 0
+
+ parentIsConnectionRoot = (
+ parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/"
+ )
+
+ siblings = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
+ affected = 0
+ for sib in siblings:
+ sibId = _getRecordValue(sib, "id")
+ if sibId == parentId:
+ continue
+ sibSourceType = _getRecordValue(sib, "sourceType")
+ sibPath = _normalisePath(_getRecordValue(sib, "path"))
+ if parentIsConnectionRoot:
+ # Connection-root resets everything else under this connection.
+ pass
+ else:
+ if sibSourceType != parentSourceType:
+ continue
+ if not _isAncestorPath(parentPath, sibPath):
+ continue
+ sibVal = _getRecordValue(sib, flag)
+ if not _isExplicit(sibVal):
+ continue
+ try:
+ rootIf.db.recordModify(DataSource, sibId, {flag: None})
+ affected += 1
+ except Exception as exc:
+ logger.warning("Cascade-reset failed for DataSource %s flag=%s: %s", sibId, flag, exc)
+ if affected:
+ logger.info(
+ "Cascade-reset %s on %d descendants of DataSource (connectionId=%s, sourceType=%s, path=%s, connectionRoot=%s)",
+ flag, affected, connectionId, parentSourceType, parentPath, parentIsConnectionRoot,
+ )
+ return affected
+
+
+def _fdsClassify(fds: Dict[str, Any]) -> str:
+ """Return 'workspace' | 'table' | 'record' based on the FDS identifier shape."""
+ tableName = _getRecordValue(fds, "tableName") or ""
+ recordFilter = _getRecordValue(fds, "recordFilter")
+ if tableName == "*":
+ return "workspace"
+ if not recordFilter:
+ return "table"
+ return "record"
+
+
+def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool:
+ """Return True iff `parent` FDS is a strict ancestor of `child` FDS.
+
+ Hierarchy within one `workspaceInstanceId`:
+ workspace-wildcard (tableName='*') → table-wildcard (tableName='X', !recordFilter)
+ → record-fds (tableName='X', recordFilter.id=...)
+ table-wildcard (tableName='X') → record-fds (tableName='X', recordFilter.id=...)
+ """
+ parentWsId = _getRecordValue(parent, "workspaceInstanceId")
+ childWsId = _getRecordValue(child, "workspaceInstanceId")
+ if not parentWsId or parentWsId != childWsId:
+ return False
+ if _getRecordValue(parent, "id") == _getRecordValue(child, "id"):
+ return False
+ parentKind = _fdsClassify(parent)
+ childKind = _fdsClassify(child)
+ if parentKind == "workspace":
+ return childKind in ("table", "record")
+ if parentKind == "table":
+ if childKind != "record":
+ return False
+ return _getRecordValue(parent, "tableName") == _getRecordValue(child, "tableName")
+ return False
+
+
+def getEffectiveFlagFds(
+ rec: Dict[str, Any],
+ flag: str,
+ sameWorkspaceFds: Iterable[Dict[str, Any]],
+) -> Any:
+ """Resolve effective value of a FeatureDataSource flag.
+
+ Order: own (if explicit) → table-wildcard (if explicit) →
+ workspace-wildcard (if explicit) → static default.
+ """
+ if flag not in ("neutralize", "scope"):
+ raise ValueError(f"Unknown inheritable FDS flag: {flag}")
+ own = _getRecordValue(rec, flag)
+ if _isExplicit(own):
+ return own
+ workspaceFds: List[Dict[str, Any]] = list(sameWorkspaceFds)
+ ancestors = [a for a in workspaceFds if _fdsIsAncestor(a, rec)]
+ ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1)
+ for ancestor in ancestors:
+ val = _getRecordValue(ancestor, flag)
+ if _isExplicit(val):
+ return val
+ return _flagDefault(flag)
+
+
+def cascadeResetDescendantsFds(
+ rootIf: Any,
+ parentRec: Dict[str, Any],
+ flag: str,
+) -> int:
+ """Reset explicit `flag` to NULL on every descendant FDS of `parentRec`.
+
+ Only the targeted flag is reset; other flags on descendants are untouched.
+ Returns the number of records updated.
+ """
+ if flag not in ("neutralize", "scope"):
+ raise ValueError(f"Unknown inheritable FDS flag: {flag}")
+ from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
+
+ workspaceInstanceId = _getRecordValue(parentRec, "workspaceInstanceId")
+ if not workspaceInstanceId:
+ return 0
+ siblings = rootIf.db.getRecordset(
+ FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId}
+ )
+ affected = 0
+ for sib in siblings:
+ if not _fdsIsAncestor(parentRec, sib):
+ continue
+ sibVal = _getRecordValue(sib, flag)
+ if not _isExplicit(sibVal):
+ continue
+ sibId = _getRecordValue(sib, "id")
+ try:
+ rootIf.db.recordModify(FeatureDataSource, sibId, {flag: None})
+ affected += 1
+ except Exception as exc:
+ logger.warning("FDS cascade-reset failed for %s flag=%s: %s", sibId, flag, exc)
+ if affected:
+ logger.info(
+ "FDS cascade-reset %s on %d descendants of FDS (workspaceInstanceId=%s, kind=%s)",
+ flag, affected, workspaceInstanceId, _fdsClassify(parentRec),
+ )
+ return affected
+
+
+def buildEffectiveByConnection(
+ dataSources: Iterable[Dict[str, Any]],
+ flag: str,
+) -> Dict[str, Any]:
+ """Pre-compute the effective value of `flag` for every DataSource id.
+
+ Useful for batch operations (walker, route DTOs) that touch many records
+ at once. O(N²) in the worst case but N is bounded per connection.
+ """
+ if flag not in _INHERITABLE_FLAGS:
+ raise ValueError(f"Unknown inheritable flag: {flag}")
+ bySourceType: Dict[Tuple[str, str], List[Dict[str, Any]]] = {}
+ for ds in dataSources:
+ connId = _getRecordValue(ds, "connectionId") or ""
+ srcType = _getRecordValue(ds, "sourceType") or ""
+ bySourceType.setdefault((connId, srcType), []).append(ds)
+
+ out: Dict[str, Any] = {}
+ for group in bySourceType.values():
+ for rec in group:
+ recId = _getRecordValue(rec, "id")
+ out[recId] = getEffectiveFlag(rec, flag, group)
+ return out
diff --git a/modules/serviceCenter/services/serviceKnowledge/_progressMessages.py b/modules/serviceCenter/services/serviceKnowledge/_progressMessages.py
new file mode 100644
index 00000000..99d91d6b
--- /dev/null
+++ b/modules/serviceCenter/services/serviceKnowledge/_progressMessages.py
@@ -0,0 +1,23 @@
+"""Central i18n registration for BackgroundJob progress messages.
+
+Walkers and consumers report progress via ``progressCb(..., messageKey="…",
+messageParams={...})``. Those keys are not seen by ``t()`` at call time, so
+without a stub registration they would never make it into the boot-time
+``UiLanguageSet(xx)`` sync. Importing this module is enough to register
+every known key — call sites stay clean while translators can still find
+the texts in the standard i18n table.
+
+Keep this list in lockstep with the ``messageKey=`` arguments used in
+``subConnectorSync*.py`` and ``subConnectorIngestConsumer.py``.
+"""
+
+from modules.shared.i18nRegistry import t
+
+# Bootstrap walkers (one per connector family)
+t("{n} Dateien verarbeitet, {indexed} indexiert")
+t("{n} Tasks verarbeitet, {indexed} indexiert")
+t("{n} Mails verarbeitet, {indexed} indexiert")
+
+# Ingestion consumer hand-offs
+t("Verbindung wird aufgebaut ({authority})")
+t("Synchronisierung läuft...")
diff --git a/modules/serviceCenter/services/serviceKnowledge/_ragLimits.py b/modules/serviceCenter/services/serviceKnowledge/_ragLimits.py
new file mode 100644
index 00000000..de0a4886
--- /dev/null
+++ b/modules/serviceCenter/services/serviceKnowledge/_ragLimits.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""Centralized RAG bootstrap limits + DataSource-scoped resolution.
+
+The original walkers (SharePoint, kDrive, gDrive, ClickUp) each carried their
+own module-level `MAX_*_DEFAULT` constants and silently stopped indexing once
+they were exceeded. That made it impossible for a user with a 500 MB folder to
+override the 200 MB cap without a code change.
+
+This module is the single source of truth for two things:
+
+1. The canonical default budget per source kind (`FILES_LIMITS_DEFAULT`,
+ `CLICKUP_LIMITS_DEFAULT`). Walkers fall back to these when a DataSource has
+ no `settings.ragLimits` yet.
+
+2. The pure read/lazy-fill helpers that walkers and the API use to merge a
+ DataSource's stored settings with the defaults. No override layers, no
+ resolver chain: what is in `DataSource.settings.ragLimits` is what the
+ walker uses.
+
+Lazy fill: the first time a DataSource is processed, the defaults are written
+to its `settings.ragLimits` so the UI shows real values immediately, even if
+the user has never opened the settings modal.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, Optional
+
+
+logger = logging.getLogger(__name__)
+
+
+FILES_LIMITS_DEFAULT: Dict[str, int] = {
+ "maxItems": 500,
+ "maxBytes": 200 * 1024 * 1024,
+ "maxFileSize": 25 * 1024 * 1024,
+ "maxDepth": 4,
+}
+
+
+CLICKUP_LIMITS_DEFAULT: Dict[str, int] = {
+ "maxTasks": 500,
+ "maxWorkspaces": 3,
+ "maxListsPerWorkspace": 20,
+}
+
+
+_LIMITS_BY_KIND: Dict[str, Dict[str, int]] = {
+ "files": FILES_LIMITS_DEFAULT,
+ "clickup": CLICKUP_LIMITS_DEFAULT,
+}
+
+
+def getDefaults(kind: str) -> Dict[str, int]:
+ """Return a fresh copy of the default budget for the given walker kind.
+
+ `kind` is either "files" (Sharepoint, kDrive, gDrive) or "clickup".
+ Returning a copy lets callers mutate the result safely.
+ """
+ defaults = _LIMITS_BY_KIND.get(kind)
+ if defaults is None:
+ raise ValueError(f"Unknown RAG limit kind: {kind!r}")
+ return dict(defaults)
+
+
+def getStoredOverrides(dataSource: Optional[Dict[str, Any]], kind: str) -> Dict[str, int]:
+ """Return ONLY the limits explicitly set on `dataSource.settings.ragLimits`.
+
+ Missing keys are NOT filled with defaults — that is the caller's job (so
+ a programmatically supplied `limits=` from a Caller still wins when the
+ DataSource has no override). Pure read, no DB writes.
+ """
+ if not isinstance(dataSource, dict):
+ return {}
+ settings = dataSource.get("settings") or {}
+ if not isinstance(settings, dict):
+ return {}
+ stored = settings.get("ragLimits")
+ if not isinstance(stored, dict):
+ return {}
+ allowed = set(_LIMITS_BY_KIND.get(kind, {}).keys())
+ out: Dict[str, int] = {}
+ for key, raw in stored.items():
+ if key not in allowed or raw is None:
+ continue
+ try:
+ out[key] = int(raw)
+ except (TypeError, ValueError):
+ logger.warning(
+ "Ignoring non-int ragLimits[%s]=%r on DataSource %s",
+ key, raw, dataSource.get("id"),
+ )
+ return out
+
+
+def getRagLimits(dataSource: Optional[Dict[str, Any]], kind: str) -> Dict[str, int]:
+ """Effective RAG limits for the API/cost-estimate use-case.
+
+ Stored overrides win over `getDefaults(kind)`. Walkers should NOT use this
+ function — they should pass their own caller-limits as the fallback so that
+ a runtime-supplied `limits=` parameter is honoured (see `getStoredOverrides`).
+ """
+ base = getDefaults(kind)
+ base.update(getStoredOverrides(dataSource, kind))
+ return base
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py
index c86aed86..618a9965 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py
@@ -141,18 +141,39 @@ _SOURCE_TYPE_MAP = {
def _loadRagEnabledDataSources(connectionId: str, dataSourceIds: Optional[list] = None):
- """Load DataSource rows with ragIndexEnabled=true for a connection.
+ """Load DataSource rows whose *effective* ragIndexEnabled is True.
- If dataSourceIds is provided (mini-bootstrap), filter to only those IDs.
+ Cascade-inherit semantics: a DataSource with `ragIndexEnabled=None`
+ follows its nearest ancestor's value (path-traversal). Walker iterates
+ over all DataSources whose effective value resolves to True, including
+ inherited ones.
+
+ Returned dicts carry **resolved** flags (`neutralize`, `scope`) so the
+ downstream walkers can keep reading `ds.get("neutralize")` directly
+ without having to know about the inheritance chain.
+
+ If `dataSourceIds` is provided (mini-bootstrap), the explicit set is
+ intersected with the effective-true set.
"""
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.datamodels.datamodelDataSource import DataSource
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
rootIf = getRootInterface()
allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
+ resolved = []
+ for ds in allDs:
+ effRagIndex = getEffectiveFlag(ds, "ragIndexEnabled", allDs)
+ if effRagIndex is not True:
+ continue
+ dsCopy = dict(ds) if isinstance(ds, dict) else {**ds.__dict__}
+ dsCopy["neutralize"] = getEffectiveFlag(ds, "neutralize", allDs)
+ dsCopy["scope"] = getEffectiveFlag(ds, "scope", allDs)
+ dsCopy["ragIndexEnabled"] = True
+ resolved.append(dsCopy)
if dataSourceIds:
- return [ds for ds in allDs if ds.get("id") in dataSourceIds and ds.get("ragIndexEnabled")]
- return [ds for ds in allDs if ds.get("ragIndexEnabled")]
+ resolved = [ds for ds in resolved if ds.get("id") in dataSourceIds]
+ return resolved
async def _bootstrapJobHandler(
@@ -167,7 +188,11 @@ async def _bootstrapJobHandler(
if not connectionId:
raise ValueError("connection.bootstrap requires payload.connectionId")
- progressCb(5, f"resolving {authority} connection")
+ progressCb(
+ 5,
+ messageKey="Verbindung wird aufgebaut ({authority})",
+ messageParams={"authority": authority},
+ )
# Defensive consent check
try:
@@ -225,7 +250,7 @@ async def _bootstrapJobHandler(
bootstrapOutlook,
)
- progressCb(0, "Synchronisierung läuft...")
+ progressCb(0, messageKey="Synchronisierung läuft...")
spDs = _filterDs("sharepoint")
olDs = _filterDs("outlook")
async def _noopResult():
@@ -251,7 +276,7 @@ async def _bootstrapJobHandler(
bootstrapGmail,
)
- progressCb(0, "Synchronisierung läuft...")
+ progressCb(0, messageKey="Synchronisierung läuft...")
gdDs = _filterDs("drive")
gmDs = _filterDs("gmail")
async def _noopResult():
@@ -274,7 +299,7 @@ async def _bootstrapJobHandler(
bootstrapClickup,
)
- progressCb(0, "Synchronisierung läuft...")
+ progressCb(0, messageKey="Synchronisierung läuft...")
cuDs = _filterDs("clickup")
cuResult = await bootstrapClickup(connectionId=connectionId, progressCb=progressCb, dataSources=cuDs) if cuDs else {"skipped": True, "reason": "no_datasources"}
return {
@@ -288,7 +313,7 @@ async def _bootstrapJobHandler(
bootstrapKdrive,
)
- progressCb(0, "Synchronisierung läuft...")
+ progressCb(0, messageKey="Synchronisierung läuft...")
kdDs = _filterDs("kdrive")
kdResult = await bootstrapKdrive(connectionId=connectionId, progressCb=progressCb, dataSources=kdDs) if kdDs else {"skipped": True, "reason": "no_datasources"}
return {
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py
index 959e42c9..28c24275 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py
@@ -33,13 +33,21 @@ from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
logger = logging.getLogger(__name__)
-MAX_TASKS_DEFAULT = 500
-MAX_WORKSPACES_DEFAULT = 3
-MAX_LISTS_PER_WORKSPACE_DEFAULT = 20
+from modules.serviceCenter.services.serviceKnowledge import _ragLimits as _ragLimitsHelper
+
+_CLICKUP_DEFAULTS = _ragLimitsHelper.CLICKUP_LIMITS_DEFAULT
+MAX_TASKS_DEFAULT = _CLICKUP_DEFAULTS["maxTasks"]
+MAX_WORKSPACES_DEFAULT = _CLICKUP_DEFAULTS["maxWorkspaces"]
+MAX_LISTS_PER_WORKSPACE_DEFAULT = _CLICKUP_DEFAULTS["maxListsPerWorkspace"]
MAX_DESCRIPTION_CHARS_DEFAULT = 8000
MAX_AGE_DAYS_DEFAULT = 180
+def _resolveDataSourceLimits(dsId: str, ds: Dict[str, Any]) -> Dict[str, int]:
+ """Return explicit RAG-limit overrides stored on the DataSource (or {})."""
+ return _ragLimitsHelper.getStoredOverrides(ds, "clickup")
+
+
@dataclass
class ClickupBootstrapLimits:
maxTasks: int = MAX_TASKS_DEFAULT
@@ -236,10 +244,11 @@ async def bootstrapClickup(
dsId = ds.get("id", "")
dsNeutralize = ds.get("neutralize", False)
+ eff = _resolveDataSourceLimits(dsId, ds)
dsLimits = ClickupBootstrapLimits(
- maxTasks=limits.maxTasks,
- maxWorkspaces=limits.maxWorkspaces,
- maxListsPerWorkspace=limits.maxListsPerWorkspace,
+ maxTasks=eff.get("maxTasks", limits.maxTasks),
+ maxWorkspaces=eff.get("maxWorkspaces", limits.maxWorkspaces),
+ maxListsPerWorkspace=eff.get("maxListsPerWorkspace", limits.maxListsPerWorkspace),
maxDescriptionChars=limits.maxDescriptionChars,
maxAgeDays=limits.maxAgeDays,
includeClosed=limits.includeClosed,
@@ -520,7 +529,11 @@ async def _ingestTask(
if hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
return
try:
- progressCb(0, f"{processed} Tasks verarbeitet, {result.indexed} indexiert")
+ progressCb(
+ 0,
+ messageKey="{n} Tasks verarbeitet, {indexed} indexiert",
+ messageParams={"n": processed, "indexed": result.indexed},
+ )
except Exception:
pass
if processed % 50 == 0:
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py
index e27abacb..7600cce0 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py
@@ -31,13 +31,21 @@ from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
logger = logging.getLogger(__name__)
-MAX_ITEMS_DEFAULT = 500
-MAX_BYTES_DEFAULT = 200 * 1024 * 1024
-MAX_FILE_SIZE_DEFAULT = 25 * 1024 * 1024
+from modules.serviceCenter.services.serviceKnowledge import _ragLimits as _ragLimitsHelper
+
+_FILES_DEFAULTS = _ragLimitsHelper.FILES_LIMITS_DEFAULT
+MAX_ITEMS_DEFAULT = _FILES_DEFAULTS["maxItems"]
+MAX_BYTES_DEFAULT = _FILES_DEFAULTS["maxBytes"]
+MAX_FILE_SIZE_DEFAULT = _FILES_DEFAULTS["maxFileSize"]
+MAX_DEPTH_DEFAULT = _FILES_DEFAULTS["maxDepth"]
SKIP_MIME_PREFIXES_DEFAULT = ("video/", "audio/")
-MAX_DEPTH_DEFAULT = 4
MAX_AGE_DAYS_DEFAULT = 365
+
+def _resolveDataSourceLimits(dsId: str, ds: Dict[str, Any]) -> Dict[str, int]:
+ """Return explicit RAG-limit overrides stored on the DataSource (or {})."""
+ return _ragLimitsHelper.getStoredOverrides(ds, "files")
+
FOLDER_MIME = "application/vnd.google-apps.folder"
@@ -175,12 +183,13 @@ async def bootstrapGdrive(
dsId = ds.get("id", "")
dsNeutralize = ds.get("neutralize", False)
dsMaxAgeDays = ds.get("maxAgeDays", limits.maxAgeDays)
+ eff = _resolveDataSourceLimits(dsId, ds)
dsLimits = GdriveBootstrapLimits(
- maxItems=limits.maxItems,
- maxBytes=limits.maxBytes,
- maxFileSize=limits.maxFileSize,
+ maxItems=eff.get("maxItems", limits.maxItems),
+ maxBytes=eff.get("maxBytes", limits.maxBytes),
+ maxFileSize=eff.get("maxFileSize", limits.maxFileSize),
skipMimePrefixes=limits.skipMimePrefixes,
- maxDepth=limits.maxDepth,
+ maxDepth=eff.get("maxDepth", limits.maxDepth),
maxAgeDays=dsMaxAgeDays,
neutralize=dsNeutralize,
)
@@ -459,7 +468,11 @@ async def _ingestOne(
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
- progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert")
+ progressCb(
+ 0,
+ messageKey="{n} Dateien verarbeitet, {indexed} indexiert",
+ messageParams={"n": processed, "indexed": result.indexed},
+ )
except Exception:
pass
logger.info(
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py
index 3130e942..96f9cecf 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py
@@ -474,7 +474,11 @@ async def _ingestMessage(
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
- progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert")
+ progressCb(
+ 0,
+ messageKey="{n} Mails verarbeitet, {indexed} indexiert",
+ messageParams={"n": processed, "indexed": result.indexed},
+ )
except Exception:
pass
if processed % 50 == 0:
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py
index dcf19e39..f95aafd1 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py
@@ -27,11 +27,19 @@ from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
logger = logging.getLogger(__name__)
-MAX_ITEMS_DEFAULT = 500
-MAX_BYTES_DEFAULT = 200 * 1024 * 1024
-MAX_FILE_SIZE_DEFAULT = 25 * 1024 * 1024
+from modules.serviceCenter.services.serviceKnowledge import _ragLimits as _ragLimitsHelper
+
+_FILES_DEFAULTS = _ragLimitsHelper.FILES_LIMITS_DEFAULT
+MAX_ITEMS_DEFAULT = _FILES_DEFAULTS["maxItems"]
+MAX_BYTES_DEFAULT = _FILES_DEFAULTS["maxBytes"]
+MAX_FILE_SIZE_DEFAULT = _FILES_DEFAULTS["maxFileSize"]
+MAX_DEPTH_DEFAULT = _FILES_DEFAULTS["maxDepth"]
SKIP_MIME_PREFIXES_DEFAULT = ("video/", "audio/")
-MAX_DEPTH_DEFAULT = 4
+
+
+def _resolveDataSourceLimits(dsId: str, ds: Dict[str, Any]) -> Dict[str, int]:
+ """Return explicit RAG-limit overrides stored on the DataSource (or {})."""
+ return _ragLimitsHelper.getStoredOverrides(ds, "files")
@dataclass
@@ -143,12 +151,13 @@ async def bootstrapKdrive(
dsPath = ds.get("path", "")
dsId = ds.get("id", "")
dsNeutralize = ds.get("neutralize", False)
+ eff = _resolveDataSourceLimits(dsId, ds)
dsLimits = KdriveBootstrapLimits(
- maxItems=limits.maxItems,
- maxBytes=limits.maxBytes,
- maxFileSize=limits.maxFileSize,
+ maxItems=eff.get("maxItems", limits.maxItems),
+ maxBytes=eff.get("maxBytes", limits.maxBytes),
+ maxFileSize=eff.get("maxFileSize", limits.maxFileSize),
skipMimePrefixes=limits.skipMimePrefixes,
- maxDepth=limits.maxDepth,
+ maxDepth=eff.get("maxDepth", limits.maxDepth),
neutralize=dsNeutralize,
)
@@ -416,7 +425,11 @@ async def _ingestOne(
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
- progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert")
+ progressCb(
+ 0,
+ messageKey="{n} Dateien verarbeitet, {indexed} indexiert",
+ messageParams={"n": processed, "indexed": result.indexed},
+ )
except Exception:
pass
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py
index 17220d97..e676b156 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py
@@ -460,7 +460,11 @@ async def _ingestMessage(
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
- progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert")
+ progressCb(
+ 0,
+ messageKey="{n} Mails verarbeitet, {indexed} indexiert",
+ messageParams={"n": processed, "indexed": result.indexed},
+ )
except Exception:
pass
if processed % 50 == 0:
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py
index e06fd36b..87c4c92a 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py
@@ -30,14 +30,27 @@ from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
logger = logging.getLogger(__name__)
-MAX_ITEMS_DEFAULT = 500
-MAX_BYTES_DEFAULT = 200 * 1024 * 1024
-MAX_FILE_SIZE_DEFAULT = 25 * 1024 * 1024
+from modules.serviceCenter.services.serviceKnowledge import _ragLimits as _ragLimitsHelper
+
+_FILES_DEFAULTS = _ragLimitsHelper.FILES_LIMITS_DEFAULT
+MAX_ITEMS_DEFAULT = _FILES_DEFAULTS["maxItems"]
+MAX_BYTES_DEFAULT = _FILES_DEFAULTS["maxBytes"]
+MAX_FILE_SIZE_DEFAULT = _FILES_DEFAULTS["maxFileSize"]
+MAX_DEPTH_DEFAULT = _FILES_DEFAULTS["maxDepth"]
SKIP_MIME_PREFIXES_DEFAULT = ("video/", "audio/")
-MAX_DEPTH_DEFAULT = 4
MAX_SITES_DEFAULT = 3
+def _resolveDataSourceLimits(dsId: str, ds: Dict[str, Any]) -> Dict[str, int]:
+ """Return explicit RAG-limit overrides stored on the DataSource.
+
+ Empty dict means "use caller-supplied limits" — never overrides them with
+ defaults. Used to merge per-DataSource user settings on top of the
+ walker's runtime limits.
+ """
+ return _ragLimitsHelper.getStoredOverrides(ds, "files")
+
+
@dataclass
class SharepointBootstrapLimits:
maxItems: int = MAX_ITEMS_DEFAULT
@@ -165,12 +178,13 @@ async def bootstrapSharepoint(
dsPath = ds.get("path", "")
dsId = ds.get("id", "")
dsNeutralize = ds.get("neutralize", False)
+ eff = _resolveDataSourceLimits(dsId, ds)
dsLimits = SharepointBootstrapLimits(
- maxItems=limits.maxItems,
- maxBytes=limits.maxBytes,
- maxFileSize=limits.maxFileSize,
+ maxItems=eff.get("maxItems", limits.maxItems),
+ maxBytes=eff.get("maxBytes", limits.maxBytes),
+ maxFileSize=eff.get("maxFileSize", limits.maxFileSize),
skipMimePrefixes=limits.skipMimePrefixes,
- maxDepth=limits.maxDepth,
+ maxDepth=eff.get("maxDepth", limits.maxDepth),
maxSites=limits.maxSites,
neutralize=dsNeutralize,
)
@@ -441,7 +455,11 @@ async def _ingestOne(
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
- progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert")
+ progressCb(
+ 0,
+ messageKey="{n} Dateien verarbeitet, {indexed} indexiert",
+ messageParams={"n": processed, "indexed": result.indexed},
+ )
except Exception:
pass
if processed % 50 == 0:
diff --git a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py b/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py
index 10be150d..0deae777 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py
@@ -1,78 +1,32 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
-"""Resolve effective policies (neutralize, ragIndexEnabled) for DataSource tree hierarchies.
+"""DEPRECATED: Use `_inheritFlags.getEffectiveFlag()` directly.
-Tree-inheritance rule: nearest ancestor DataSource with an explicit value wins.
-If no ancestor has a value, the default (False) is used.
+Thin shim to the new cascade-inherit helper. Kept so external callers don't
+break on import — internal walkers consume pre-resolved dicts via
+`_loadRagEnabledDataSources`.
"""
from __future__ import annotations
-import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
-logger = logging.getLogger(__name__)
+from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
def resolveEffectiveNeutralize(
ds: Dict[str, Any],
allDataSources: List[Dict[str, Any]],
) -> bool:
- """Compute effective neutralize by walking up the path tree.
-
- A DataSource at /sites/HR/Documents inherits from /sites/HR if
- that ancestor has neutralize=True and the child has no explicit override.
- """
- ownValue = ds.get("neutralize")
- if ownValue is not None and ownValue is not False:
- return True
- if ownValue is False:
- return False
- return _findAncestorPolicy(ds, allDataSources, "neutralize")
+ """DEPRECATED: use `getEffectiveFlag(ds, 'neutralize', allDataSources)`."""
+ value = getEffectiveFlag(ds, "neutralize", allDataSources)
+ return bool(value)
def resolveEffectiveRagIndexEnabled(
ds: Dict[str, Any],
allDataSources: List[Dict[str, Any]],
) -> bool:
- """Compute effective ragIndexEnabled by walking up the path tree."""
- ownValue = ds.get("ragIndexEnabled")
- if ownValue is True:
- return True
- if ownValue is False:
- return False
- return _findAncestorPolicy(ds, allDataSources, "ragIndexEnabled")
-
-
-def _findAncestorPolicy(
- ds: Dict[str, Any],
- allDataSources: List[Dict[str, Any]],
- field: str,
-) -> bool:
- """Walk ancestors (longest-prefix match) to find an inherited policy value."""
- dsPath = ds.get("path", "")
- connectionId = ds.get("connectionId", "")
- if not dsPath:
- return False
-
- ancestors = []
- for candidate in allDataSources:
- if candidate.get("id") == ds.get("id"):
- continue
- if candidate.get("connectionId") != connectionId:
- continue
- candidatePath = candidate.get("path", "")
- if not candidatePath:
- continue
- if dsPath.startswith(candidatePath) and len(candidatePath) < len(dsPath):
- ancestors.append(candidate)
-
- ancestors.sort(key=lambda a: len(a.get("path", "")), reverse=True)
-
- for ancestor in ancestors:
- val = ancestor.get(field)
- if val is True:
- return True
- if val is False:
- return False
- return False
+ """DEPRECATED: use `getEffectiveFlag(ds, 'ragIndexEnabled', allDataSources)`."""
+ value = getEffectiveFlag(ds, "ragIndexEnabled", allDataSources)
+ return bool(value)
diff --git a/modules/shared/i18nRegistry.py b/modules/shared/i18nRegistry.py
index 7e620f8d..06ccb20e 100644
--- a/modules/shared/i18nRegistry.py
+++ b/modules/shared/i18nRegistry.py
@@ -124,6 +124,48 @@ def t(key: str, context: str = "api", value: str = "") -> str:
return _CACHE.get(lang, {}).get(key, f"[{key}]")
+def resolveJobMessage(messageData: Optional[Dict[str, Any]], lang: Optional[str] = None) -> Optional[str]:
+ """Translate a structured BackgroundJob progress payload.
+
+ ``messageData`` shape (written by ``JobProgressCallback`` when callers
+ pass ``messageKey`` / ``messageParams``)::
+
+ {"key": "{n} Dateien verarbeitet, {indexed} indexiert",
+ "params": {"n": 145, "indexed": 106}}
+
+ The walker call sites use a string-literal ``messageKey=``; the matching
+ ``t("…")`` literal lives in the feature's progress-key registration
+ module (e.g. ``serviceKnowledge/_progressMessages.py``,
+ ``features/trustee/mainTrustee.py``) so the boot sync picks it up.
+
+ This helper is the **server-side** translation hop so route handlers can
+ deliver a fully rendered ``progressMessage`` string to the frontend --
+ the frontend never calls ``t()`` on backend-supplied keys.
+ """
+ if not messageData or not isinstance(messageData, dict):
+ return None
+ key = messageData.get("key")
+ if not isinstance(key, str) or not key:
+ return None
+ params = messageData.get("params") or {}
+
+ if lang is not None:
+ token = _CURRENT_LANGUAGE.set(lang)
+ try:
+ template = t(key)
+ finally:
+ _CURRENT_LANGUAGE.reset(token)
+ else:
+ template = t(key)
+
+ if isinstance(params, dict) and params:
+ try:
+ return template.format(**params)
+ except (KeyError, IndexError, ValueError):
+ return template
+ return template
+
+
def resolveText(value: Any, lang: Optional[str] = None) -> str:
"""Resolve any value to a translated string for the current request language.
diff --git a/scripts/debug_rag_job_result.py b/scripts/debug_rag_job_result.py
new file mode 100644
index 00000000..c107f21e
--- /dev/null
+++ b/scripts/debug_rag_job_result.py
@@ -0,0 +1,70 @@
+"""Diagnose: read a connection.bootstrap job result and print its keys.
+
+Usage (from repo root):
+ python gateway\scripts\debug_rag_job_result.py
+
+Prints the most recent SUCCESS connection.bootstrap job per UserConnection so
+we can see whether the `stoppedAtLimit` key actually landed in the JSONB
+`result` column. If it is missing here, the bug is in the writer (handler or
+_markSuccess); if it is present here but absent in the HTTP response, the bug
+is in routeRagInventory.
+"""
+from __future__ import annotations
+
+import os
+import sys
+import json
+from pathlib import Path
+
+_HERE = Path(__file__).resolve()
+sys.path.insert(0, str(_HERE.parent.parent)) # gateway/
+os.chdir(_HERE.parent.parent)
+
+from modules.shared.configuration import APP_CONFIG # noqa: E402
+from modules.connectors.connectorDbPostgre import getCachedConnector # noqa: E402
+from modules.datamodels.datamodelBackgroundJob import BackgroundJob # noqa: E402
+from modules.routes.routeRagInventory import _flattenJobResult # noqa: E402
+
+
+def _main() -> None:
+ db = getCachedConnector(
+ dbDatabase=APP_CONFIG.get("DB_DATABASE", "poweron_app"),
+ dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
+ dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
+ dbUser=APP_CONFIG.get("DB_USER"),
+ dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
+ )
+
+ rows = db.getRecordset(BackgroundJob)
+ rows = [r for r in rows if r.get("jobType") == "connection.bootstrap"]
+ rows = [r for r in rows if r.get("status") == "SUCCESS"]
+ rows.sort(key=lambda r: r.get("createdAt") or 0, reverse=True)
+
+ if not rows:
+ print("No SUCCESS connection.bootstrap jobs found.")
+ return
+
+ seenConnections: set[str] = set()
+ for j in rows:
+ connId = (j.get("payload") or {}).get("connectionId", "")
+ if connId in seenConnections:
+ continue
+ seenConnections.add(connId)
+ result = j.get("result") or {}
+ flat = _flattenJobResult(result) if isinstance(result, dict) else {}
+ print("=" * 80)
+ print(f"jobId = {j.get('id')}")
+ print(f"connectionId = {connId}")
+ print(f"finishedAt = {j.get('finishedAt')}")
+ print(f"raw keys = {sorted(result.keys()) if isinstance(result, dict) else 'N/A'}")
+ print("--- flattened (what the API will return now) ---")
+ print(f" indexed = {flat.get('indexed')}")
+ print(f" skippedDuplicate= {flat.get('skippedDuplicate')}")
+ print(f" skippedPolicy = {flat.get('skippedPolicy')}")
+ print(f" stoppedAtLimit = {flat.get('stoppedAtLimit')!r} <-- KEY CHECK")
+ print(f" limits = {flat.get('limits')}")
+ print(f" bytesProcessed = {flat.get('bytesProcessed')}")
+
+
+if __name__ == "__main__":
+ _main()
diff --git a/scripts/script_db_migrate_backgroundjob_progress_data.py b/scripts/script_db_migrate_backgroundjob_progress_data.py
new file mode 100644
index 00000000..bc5fc348
--- /dev/null
+++ b/scripts/script_db_migrate_backgroundjob_progress_data.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+"""Migration: Add `progressMessageData` JSONB column to BackgroundJob.
+
+Carries the structured i18n payload that lets the frontend translate
+walker progress messages (e.g. "{n} Dateien verarbeitet, {indexed}
+indexiert") into the user's UI language. `progressMessage` stays around
+as the rendered fallback for older clients and audit logs.
+
+Safe to run multiple times (checks column existence before acting).
+
+Usage:
+ python scripts/script_db_migrate_backgroundjob_progress_data.py [--dry-run]
+"""
+
+import os
+import sys
+import argparse
+import logging
+from pathlib import Path
+
+scriptPath = Path(__file__).resolve()
+gatewayPath = scriptPath.parent.parent
+sys.path.insert(0, str(gatewayPath))
+os.chdir(str(gatewayPath))
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
+logger = logging.getLogger(__name__)
+
+import psycopg2
+from modules.shared.configuration import APP_CONFIG
+
+
+def _getConnection():
+ return psycopg2.connect(
+ host=APP_CONFIG.get("DB_HOST", "localhost"),
+ port=int(APP_CONFIG.get("DB_PORT", "5432")),
+ database=APP_CONFIG.get("DB_DATABASE", "poweron_app"),
+ user=APP_CONFIG.get("DB_USER"),
+ password=APP_CONFIG.get("DB_PASSWORD_SECRET"),
+ )
+
+
+def _columnExists(cur, table: str, column: str) -> bool:
+ cur.execute(
+ """SELECT 1 FROM information_schema.columns
+ WHERE table_schema = 'public' AND table_name = %s AND column_name = %s""",
+ (table, column),
+ )
+ return cur.fetchone() is not None
+
+
+def _tableExists(cur, table: str) -> bool:
+ cur.execute(
+ """SELECT 1 FROM information_schema.tables
+ WHERE table_schema = 'public' AND table_name = %s""",
+ (table,),
+ )
+ return cur.fetchone() is not None
+
+
+def migrate(dryRun: bool = False):
+ conn = _getConnection()
+ conn.autocommit = False
+ cur = conn.cursor()
+
+ table, column = "BackgroundJob", "progressMessageData"
+ executed = []
+
+ if not _tableExists(cur, table):
+ logger.warning("SKIP: table %s does not exist yet (will be created on next ORM init)", table)
+ elif _columnExists(cur, table, column):
+ logger.info("SKIP: %s.%s already exists", table, column)
+ else:
+ sql = f'ALTER TABLE public."{table}" ADD COLUMN "{column}" JSONB DEFAULT NULL;'
+ logger.info("EXEC: %s", sql)
+ if not dryRun:
+ cur.execute(sql)
+ executed.append(sql)
+
+ if not dryRun and executed:
+ conn.commit()
+ logger.info("Migration committed (%d statements)", len(executed))
+ elif dryRun and executed:
+ conn.rollback()
+ logger.info("DRY RUN -- would execute %d statements", len(executed))
+ else:
+ logger.info("Nothing to do -- schema already up to date")
+
+ cur.close()
+ conn.close()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("--dry-run", action="store_true", help="Print SQL without executing")
+ args = parser.parse_args()
+ migrate(dryRun=args.dry_run)
diff --git a/scripts/script_db_migrate_datasource_inherit.py b/scripts/script_db_migrate_datasource_inherit.py
new file mode 100644
index 00000000..3444cbee
--- /dev/null
+++ b/scripts/script_db_migrate_datasource_inherit.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+"""Migration: Drop NOT NULL on DataSource/FeatureDataSource cascade-inherit flags.
+
+Switches three-valued semantics (NULL = inherit, True/False = explicit) for:
+ - DataSource.neutralize, ragIndexEnabled, scope
+ - FeatureDataSource.neutralize, scope
+
+Existing rows keep their explicit values; only new records (or explicit reset
+via cascade) start with NULL. Migration is non-destructive and idempotent.
+
+Safe to run multiple times.
+
+Usage:
+ python scripts/script_db_migrate_datasource_inherit.py [--dry-run]
+"""
+
+import os
+import sys
+import argparse
+import logging
+from pathlib import Path
+
+scriptPath = Path(__file__).resolve()
+gatewayPath = scriptPath.parent.parent
+sys.path.insert(0, str(gatewayPath))
+os.chdir(str(gatewayPath))
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
+logger = logging.getLogger(__name__)
+
+import psycopg2
+from modules.shared.configuration import APP_CONFIG
+
+
+def _getConnection():
+ return psycopg2.connect(
+ host=APP_CONFIG.get("DB_HOST", "localhost"),
+ port=int(APP_CONFIG.get("DB_PORT", "5432")),
+ database=APP_CONFIG.get("DB_DATABASE", "poweron_app"),
+ user=APP_CONFIG.get("DB_USER"),
+ password=APP_CONFIG.get("DB_PASSWORD_SECRET"),
+ )
+
+
+def _tableExists(cur, table: str) -> bool:
+ cur.execute(
+ """SELECT 1 FROM information_schema.tables
+ WHERE table_schema = 'public' AND table_name = %s""",
+ (table,),
+ )
+ return cur.fetchone() is not None
+
+
+def _columnIsNullable(cur, table: str, column: str) -> bool:
+ cur.execute(
+ """SELECT is_nullable FROM information_schema.columns
+ WHERE table_schema = 'public' AND table_name = %s AND column_name = %s""",
+ (table, column),
+ )
+ row = cur.fetchone()
+ if not row:
+ return False
+ return row[0] == "YES"
+
+
+def migrate(dryRun: bool = False):
+ conn = _getConnection()
+ conn.autocommit = False
+ cur = conn.cursor()
+
+ targets = [
+ ("DataSource", "neutralize"),
+ ("DataSource", "ragIndexEnabled"),
+ ("DataSource", "scope"),
+ ("FeatureDataSource", "neutralize"),
+ ("FeatureDataSource", "scope"),
+ ]
+
+ executed = []
+ for table, column in targets:
+ if not _tableExists(cur, table):
+ logger.warning("SKIP: table %s does not exist yet", table)
+ continue
+ if _columnIsNullable(cur, table, column):
+ logger.info("SKIP: %s.%s already nullable", table, column)
+ continue
+ sql = f'ALTER TABLE public."{table}" ALTER COLUMN "{column}" DROP NOT NULL;'
+ logger.info("EXEC: %s", sql)
+ if not dryRun:
+ cur.execute(sql)
+ executed.append(sql)
+
+ if not dryRun and executed:
+ conn.commit()
+ logger.info("Migration committed (%d statements)", len(executed))
+ elif dryRun and executed:
+ conn.rollback()
+ logger.info("DRY RUN -- would execute %d statements", len(executed))
+ else:
+ logger.info("Nothing to do -- schema already nullable")
+
+ cur.close()
+ conn.close()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("--dry-run", action="store_true", help="Print SQL without executing")
+ args = parser.parse_args()
+ migrate(dryRun=args.dry_run)
diff --git a/scripts/script_db_migrate_datasource_settings.py b/scripts/script_db_migrate_datasource_settings.py
new file mode 100644
index 00000000..9e821221
--- /dev/null
+++ b/scripts/script_db_migrate_datasource_settings.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+"""Migration: Add `settings` JSONB column to DataSource and FeatureDataSource.
+
+This is a one-off migration for the UDB DataSource Settings (Settings-Icon)
+feature: walkers read RAG limits (maxBytes, maxFileSize, maxItems, maxDepth)
+from this JSON blob, the UI edits them. Existing rows get NULL until the
+next bootstrap lazy-fills sensible defaults from `_ragLimits.RAG_LIMITS_DEFAULT`.
+
+Safe to run multiple times (checks column existence before acting).
+
+Usage:
+ python scripts/script_db_migrate_datasource_settings.py [--dry-run]
+"""
+
+import os
+import sys
+import argparse
+import logging
+from pathlib import Path
+
+scriptPath = Path(__file__).resolve()
+gatewayPath = scriptPath.parent.parent
+sys.path.insert(0, str(gatewayPath))
+os.chdir(str(gatewayPath))
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
+logger = logging.getLogger(__name__)
+
+import psycopg2
+from modules.shared.configuration import APP_CONFIG
+
+
+def _getConnection():
+ return psycopg2.connect(
+ host=APP_CONFIG.get("DB_HOST", "localhost"),
+ port=int(APP_CONFIG.get("DB_PORT", "5432")),
+ database=APP_CONFIG.get("DB_DATABASE", "poweron_app"),
+ user=APP_CONFIG.get("DB_USER"),
+ password=APP_CONFIG.get("DB_PASSWORD_SECRET"),
+ )
+
+
+def _columnExists(cur, table: str, column: str) -> bool:
+ cur.execute(
+ """SELECT 1 FROM information_schema.columns
+ WHERE table_schema = 'public' AND table_name = %s AND column_name = %s""",
+ (table, column),
+ )
+ return cur.fetchone() is not None
+
+
+def _tableExists(cur, table: str) -> bool:
+ cur.execute(
+ """SELECT 1 FROM information_schema.tables
+ WHERE table_schema = 'public' AND table_name = %s""",
+ (table,),
+ )
+ return cur.fetchone() is not None
+
+
+def migrate(dryRun: bool = False):
+ conn = _getConnection()
+ conn.autocommit = False
+ cur = conn.cursor()
+
+ targets = [
+ ("DataSource", "settings"),
+ ("FeatureDataSource", "settings"),
+ ]
+
+ executed = []
+ for table, column in targets:
+ if not _tableExists(cur, table):
+ logger.warning("SKIP: table %s does not exist yet (will be created on next ORM init)", table)
+ continue
+ if _columnExists(cur, table, column):
+ logger.info("SKIP: %s.%s already exists", table, column)
+ continue
+ sql = f'ALTER TABLE public."{table}" ADD COLUMN "{column}" JSONB DEFAULT NULL;'
+ logger.info("EXEC: %s", sql)
+ if not dryRun:
+ cur.execute(sql)
+ executed.append(sql)
+
+ if not dryRun and executed:
+ conn.commit()
+ logger.info("Migration committed (%d statements)", len(executed))
+ elif dryRun and executed:
+ conn.rollback()
+ logger.info("DRY RUN -- would execute %d statements", len(executed))
+ else:
+ logger.info("Nothing to do -- schema already up to date")
+
+ cur.close()
+ conn.close()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("--dry-run", action="store_true", help="Print SQL without executing")
+ args = parser.parse_args()
+ migrate(dryRun=args.dry_run)
diff --git a/tests/unit/services/test_costEstimate.py b/tests/unit/services/test_costEstimate.py
new file mode 100644
index 00000000..e49aca6a
--- /dev/null
+++ b/tests/unit/services/test_costEstimate.py
@@ -0,0 +1,55 @@
+"""Unit tests for `_costEstimate` heuristic.
+
+Validates the output shape, basic formulas, and that 'basis' annotations
+are always present (the user-facing transparency contract).
+"""
+
+from __future__ import annotations
+
+import unittest
+
+from modules.serviceCenter.services.serviceKnowledge import _costEstimate
+
+
+class TestCostEstimate(unittest.TestCase):
+ def test_files_shape(self):
+ result = _costEstimate.estimateBootstrapCost(
+ {"maxBytes": 200 * 1024 * 1024}, kind="files",
+ )
+ self.assertIn("estimatedTokens", result)
+ self.assertIn("estimatedUsd", result)
+ self.assertIn("basis", result)
+ self.assertIn("assumptions", result["basis"])
+ self.assertIn("formula", result["basis"]["assumptions"])
+ self.assertIn("notes", result["basis"])
+
+ def test_files_doubling_maxBytes_doubles_tokens(self):
+ low = _costEstimate.estimateBootstrapCost({"maxBytes": 100 * 1024 * 1024}, kind="files")
+ high = _costEstimate.estimateBootstrapCost({"maxBytes": 200 * 1024 * 1024}, kind="files")
+ self.assertEqual(high["estimatedTokens"], low["estimatedTokens"] * 2)
+
+ def test_clickup_uses_tasks_and_workspaces(self):
+ result = _costEstimate.estimateBootstrapCost(
+ {"maxTasks": 100, "maxWorkspaces": 2, "maxListsPerWorkspace": 10},
+ kind="clickup",
+ )
+ expectedTokens = 100 * 2 * _costEstimate.DEFAULT_TOKENS_PER_ITEM
+ self.assertEqual(result["estimatedTokens"], expectedTokens)
+
+ def test_unknown_kind_returns_zero(self):
+ result = _costEstimate.estimateBootstrapCost({}, kind="totally-unknown")
+ self.assertEqual(result["estimatedTokens"], 0)
+ self.assertEqual(result["estimatedUsd"], 0.0)
+
+ def test_usd_is_rounded_4_decimals(self):
+ result = _costEstimate.estimateBootstrapCost({"maxBytes": 1024 * 1024}, kind="files")
+ rounded = round(result["estimatedUsd"], 4)
+ self.assertEqual(result["estimatedUsd"], rounded)
+
+ def test_basis_includes_input_limits(self):
+ result = _costEstimate.estimateBootstrapCost({"maxBytes": 42}, kind="files")
+ self.assertEqual(result["basis"]["limits"]["maxBytes"], 42)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unit/services/test_inheritFlags.py b/tests/unit/services/test_inheritFlags.py
new file mode 100644
index 00000000..b177e767
--- /dev/null
+++ b/tests/unit/services/test_inheritFlags.py
@@ -0,0 +1,330 @@
+"""Unit tests for `_inheritFlags` cascade-inherit helpers.
+
+Verifies:
+- getEffectiveFlag walks ancestors via path-prefix matching
+- root default is False (or 'personal' for scope) when nothing explicit in chain
+- only same-connectionId AND same-sourceType ancestors are considered
+- cascadeResetDescendants only touches descendants with explicit values for THAT flag
+- '/' is treated as ancestor of every non-root path
+- '/foo' is NOT ancestor of '/foobar' (must require '/' separator)
+"""
+
+from __future__ import annotations
+
+import unittest
+from typing import List
+from unittest.mock import MagicMock
+
+from modules.serviceCenter.services.serviceKnowledge import _inheritFlags
+
+
+def _ds(idVal: str, path: str, **flags) -> dict:
+ """Build a DataSource dict with sensible defaults for a fixture."""
+ base = {
+ "id": idVal,
+ "connectionId": "conn-1",
+ "sourceType": "sharepointFolder",
+ "path": path,
+ "neutralize": None,
+ "ragIndexEnabled": None,
+ "scope": None,
+ }
+ base.update(flags)
+ return base
+
+
+class TestEffectiveFlag(unittest.TestCase):
+ def test_explicit_own_value_wins(self):
+ root = _ds("r", "/", neutralize=False)
+ leaf = _ds("l", "/folder/sub", neutralize=True)
+ self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf]))
+
+ def test_inherits_from_root_when_own_is_none(self):
+ root = _ds("r", "/", neutralize=True)
+ leaf = _ds("l", "/folder/sub")
+ self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf]))
+
+ def test_default_false_when_chain_empty(self):
+ leaf = _ds("l", "/folder/sub")
+ self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [leaf]))
+
+ def test_nearest_ancestor_wins_over_distant(self):
+ root = _ds("r", "/", neutralize=False)
+ mid = _ds("m", "/folder", neutralize=True)
+ leaf = _ds("l", "/folder/sub")
+ self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, mid, leaf]))
+
+ def test_different_connection_ignored(self):
+ otherConn = _ds("o", "/", connectionId="conn-2", neutralize=True)
+ leaf = _ds("l", "/folder")
+ self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [otherConn, leaf]))
+
+ def test_different_sourcetype_ignored(self):
+ otherType = _ds("o", "/", sourceType="outlookFolder", neutralize=True)
+ leaf = _ds("l", "/folder")
+ self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [otherType, leaf]))
+
+ def test_path_separator_required(self):
+ """`/foo` must NOT be ancestor of `/foobar` (no shared `/` boundary)."""
+ notAncestor = _ds("a", "/foo", neutralize=True)
+ leaf = _ds("l", "/foobar")
+ self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [notAncestor, leaf]))
+
+ def test_root_is_ancestor_of_everything(self):
+ root = _ds("r", "/", neutralize=True)
+ leaf = _ds("l", "/anything/anywhere")
+ self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf]))
+
+ def test_scope_inheritance_with_string_default(self):
+ root = _ds("r", "/", scope="mandate")
+ leaf = _ds("l", "/folder")
+ self.assertEqual(_inheritFlags.getEffectiveFlag(leaf, "scope", [root, leaf]), "mandate")
+
+ def test_scope_default_personal_when_empty(self):
+ leaf = _ds("l", "/folder")
+ self.assertEqual(_inheritFlags.getEffectiveFlag(leaf, "scope", [leaf]), "personal")
+
+ def test_unknown_flag_raises(self):
+ leaf = _ds("l", "/")
+ with self.assertRaises(ValueError):
+ _inheritFlags.getEffectiveFlag(leaf, "unknownFlag", [leaf])
+
+ def test_explicit_false_overrides_inherited_true(self):
+ """Explicit False on a child must NOT cascade up to True from an ancestor."""
+ root = _ds("r", "/", neutralize=True)
+ leaf = _ds("l", "/folder", neutralize=False)
+ self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf]))
+
+ def test_connection_root_inherits_cross_sourcetype(self):
+ """Connection-root (sourceType=authority, path='/') is ancestor of all DS in that connection."""
+ connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
+ spService = _ds("sp", "/", sourceType="sharepointFolder")
+ olService = _ds("ol", "/", sourceType="outlookFolder")
+ self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", [connRoot, spService, olService]))
+ self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", [connRoot, spService, olService]))
+
+ def test_same_sourcetype_ancestor_wins_over_connection_root(self):
+ """A same-sourceType service-root ancestor beats the connection-root."""
+ connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
+ spRoot = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False)
+ spLeaf = _ds("spl", "/sites/x", sourceType="sharepointFolder")
+ self.assertFalse(_inheritFlags.getEffectiveFlag(spLeaf, "neutralize", [connRoot, spRoot, spLeaf]))
+
+ def test_connection_root_does_not_self_inherit(self):
+ """Connection-root has no ancestor — does not infinite-loop on itself."""
+ connRoot = _ds("conn", "/", sourceType="msft")
+ self.assertFalse(_inheritFlags.getEffectiveFlag(connRoot, "neutralize", [connRoot]))
+
+
+class TestCascadeReset(unittest.TestCase):
+ def _makeRootIf(self, dataSources: List[dict]):
+ rootIf = MagicMock()
+ rootIf.db.getRecordset = MagicMock(return_value=dataSources)
+ modified = []
+
+ def _modify(model, recordId, fields):
+ modified.append((recordId, fields))
+ rootIf.db.recordModify = MagicMock(side_effect=_modify)
+ return rootIf, modified
+
+ def test_resets_only_explicit_descendants(self):
+ parent = _ds("p", "/sites", neutralize=True)
+ explicitChild = _ds("c1", "/sites/folder1", neutralize=False)
+ inheritChild = _ds("c2", "/sites/folder2") # inherit -> not touched
+ sibling = _ds("s", "/other", neutralize=True) # NOT a descendant
+ rootIf, modified = self._makeRootIf([parent, explicitChild, inheritChild, sibling])
+
+ affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
+
+ self.assertEqual(affected, 1)
+ self.assertEqual(modified, [("c1", {"neutralize": None})])
+
+ def test_does_not_touch_other_flags(self):
+ parent = _ds("p", "/sites", neutralize=True)
+ child = _ds("c", "/sites/sub", neutralize=False, ragIndexEnabled=True)
+ rootIf, modified = self._makeRootIf([parent, child])
+
+ _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
+
+ self.assertEqual(modified, [("c", {"neutralize": None})])
+ # ragIndexEnabled and scope on the child must remain untouched.
+
+ def test_does_not_cross_sourcetype(self):
+ """Non-connection-root parents stay within their sourceType for cascade."""
+ parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder")
+ otherTypeDescendant = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder")
+ rootIf, modified = self._makeRootIf([parent, otherTypeDescendant])
+
+ affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
+
+ self.assertEqual(affected, 0)
+ self.assertEqual(modified, [])
+
+ def test_connection_root_cascades_cross_sourcetype(self):
+ """Toggle on connection-root cascades into every explicit DS of that connection."""
+ connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
+ spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False)
+ olInherit = _ds("ol", "/", sourceType="outlookFolder")
+ spLeafExplicit = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True)
+ rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeafExplicit])
+
+ affected = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize")
+
+ # spExplicit and spLeafExplicit had explicit values → reset. olInherit untouched.
+ self.assertEqual(affected, 2)
+ self.assertEqual({m[0] for m in modified}, {"sp", "sp-leaf"})
+ for _, fields in modified:
+ self.assertEqual(fields, {"neutralize": None})
+
+ def test_unknown_flag_raises(self):
+ parent = _ds("p", "/", neutralize=True)
+ rootIf, _ = self._makeRootIf([parent])
+ with self.assertRaises(ValueError):
+ _inheritFlags.cascadeResetDescendants(rootIf, parent, "unknownFlag")
+
+
+def _fds(idVal: str, *, tableName: str, recordFilter=None, **flags) -> dict:
+ """Build a FeatureDataSource dict fixture."""
+ base = {
+ "id": idVal,
+ "workspaceInstanceId": "ws-1",
+ "tableName": tableName,
+ "recordFilter": recordFilter,
+ "neutralize": None,
+ "scope": None,
+ }
+ base.update(flags)
+ return base
+
+
+class TestFdsClassifyAndAncestry(unittest.TestCase):
+ def test_classify_workspace_wildcard(self):
+ self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="*")), "workspace")
+
+ def test_classify_table_wildcard(self):
+ self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="Pos")), "table")
+
+ def test_classify_record_specific(self):
+ rec = _fds("a", tableName="Pos", recordFilter={"id": "r-1"})
+ self.assertEqual(_inheritFlags._fdsClassify(rec), "record")
+
+ def test_workspace_is_ancestor_of_table_and_record(self):
+ ws = _fds("ws", tableName="*")
+ tbl = _fds("t", tableName="Pos")
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
+ self.assertTrue(_inheritFlags._fdsIsAncestor(ws, tbl))
+ self.assertTrue(_inheritFlags._fdsIsAncestor(ws, rec))
+
+ def test_table_is_ancestor_of_record_same_table_only(self):
+ tbl = _fds("t", tableName="Pos")
+ recSame = _fds("r1", tableName="Pos", recordFilter={"id": "1"})
+ recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"})
+ self.assertTrue(_inheritFlags._fdsIsAncestor(tbl, recSame))
+ self.assertFalse(_inheritFlags._fdsIsAncestor(tbl, recOther))
+
+ def test_record_has_no_descendants(self):
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
+ tbl = _fds("t", tableName="Pos")
+ self.assertFalse(_inheritFlags._fdsIsAncestor(rec, tbl))
+
+ def test_no_cross_workspace_ancestry(self):
+ ws = _fds("ws", tableName="*", workspaceInstanceId="ws-A")
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, workspaceInstanceId="ws-B")
+ self.assertFalse(_inheritFlags._fdsIsAncestor(ws, rec))
+
+
+class TestFdsEffectiveFlag(unittest.TestCase):
+ def test_own_explicit_wins(self):
+ ws = _fds("ws", tableName="*", neutralize=False)
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
+ self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [ws, rec]))
+
+ def test_inherits_from_table_wildcard(self):
+ tbl = _fds("t", tableName="Pos", neutralize=True)
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
+ self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [tbl, rec]))
+
+ def test_table_wildcard_beats_workspace_wildcard(self):
+ ws = _fds("ws", tableName="*", neutralize=False)
+ tbl = _fds("t", tableName="Pos", neutralize=True)
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
+ self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [ws, tbl, rec]))
+
+ def test_workspace_wildcard_inherits_when_no_table(self):
+ ws = _fds("ws", tableName="*", neutralize=True)
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
+ self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [ws, rec]))
+
+ def test_default_false_when_chain_empty(self):
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
+ self.assertFalse(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [rec]))
+
+ def test_unknown_flag_raises(self):
+ rec = _fds("r", tableName="*")
+ with self.assertRaises(ValueError):
+ _inheritFlags.getEffectiveFlagFds(rec, "ragIndexEnabled", [rec])
+
+
+class TestFdsCascadeReset(unittest.TestCase):
+ def _makeRootIf(self, fdses):
+ rootIf = MagicMock()
+ rootIf.db.getRecordset = MagicMock(return_value=fdses)
+ modified = []
+
+ def _modify(model, recordId, fields):
+ modified.append((recordId, fields))
+ rootIf.db.recordModify = MagicMock(side_effect=_modify)
+ return rootIf, modified
+
+ def test_workspace_cascades_to_all_explicit_descendants(self):
+ ws = _fds("ws", tableName="*", neutralize=True)
+ tblExplicit = _fds("t", tableName="Pos", neutralize=False)
+ tblInherit = _fds("t2", tableName="Other")
+ recExplicit = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
+ rootIf, modified = self._makeRootIf([ws, tblExplicit, tblInherit, recExplicit])
+
+ affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize")
+
+ self.assertEqual(affected, 2)
+ self.assertEqual({m[0] for m in modified}, {"t", "r"})
+
+ def test_table_cascades_only_to_same_table_records(self):
+ tbl = _fds("t", tableName="Pos", neutralize=True)
+ recSame = _fds("r1", tableName="Pos", recordFilter={"id": "1"}, neutralize=False)
+ recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}, neutralize=False)
+ rootIf, modified = self._makeRootIf([tbl, recSame, recOther])
+
+ affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize")
+
+ self.assertEqual(affected, 1)
+ self.assertEqual(modified, [("r1", {"neutralize": None})])
+
+ def test_record_has_no_cascade(self):
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
+ rootIf, modified = self._makeRootIf([rec])
+ affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize")
+ self.assertEqual(affected, 0)
+ self.assertEqual(modified, [])
+
+ def test_unknown_flag_raises(self):
+ ws = _fds("ws", tableName="*", neutralize=True)
+ rootIf, _ = self._makeRootIf([ws])
+ with self.assertRaises(ValueError):
+ _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled")
+
+
+class TestPathNormalization(unittest.TestCase):
+ def test_empty_path_normalises_to_root(self):
+ self.assertEqual(_inheritFlags._normalisePath(""), "/")
+ self.assertEqual(_inheritFlags._normalisePath(None), "/")
+
+ def test_trailing_slash_stripped(self):
+ self.assertEqual(_inheritFlags._normalisePath("/foo/"), "/foo")
+ self.assertEqual(_inheritFlags._normalisePath("/"), "/")
+
+ def test_leading_slash_added(self):
+ self.assertEqual(_inheritFlags._normalisePath("foo/bar"), "/foo/bar")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unit/services/test_knowledge_ingest_consumer.py b/tests/unit/services/test_knowledge_ingest_consumer.py
index 6b27a6e8..9884079e 100644
--- a/tests/unit/services/test_knowledge_ingest_consumer.py
+++ b/tests/unit/services/test_knowledge_ingest_consumer.py
@@ -99,11 +99,18 @@ def test_onConnectionRevoked_ignores_missing_id(monkeypatch):
assert seen == []
+def _stubRagEnabledDs(monkeypatch, dataSources):
+ """Stub _loadRagEnabledDataSources so tests don't need a live DB."""
+ monkeypatch.setattr(consumer, "_loadRagEnabledDataSources", lambda *_, **__: dataSources)
+
+
def test_bootstrap_job_skips_unsupported_authority(monkeypatch):
+ _stubRagEnabledDs(monkeypatch, [{"id": "ds1", "sourceType": "unknownType"}])
+
async def _run():
result = await consumer._bootstrapJobHandler(
{"payload": {"connectionId": "c1", "authority": "slack"}},
- lambda *_: None,
+ lambda *_, **__: None,
)
return result
@@ -114,13 +121,18 @@ def test_bootstrap_job_skips_unsupported_authority(monkeypatch):
def test_bootstrap_job_dispatches_msft_parts(monkeypatch):
+ _stubRagEnabledDs(monkeypatch, [
+ {"id": "ds1", "sourceType": "sharepointFolder"},
+ {"id": "ds2", "sourceType": "outlookFolder"},
+ ])
+
calls = {"sp": 0, "ol": 0}
- async def _fakeSp(connectionId, progressCb=None):
+ async def _fakeSp(connectionId, progressCb=None, dataSources=None):
calls["sp"] += 1
return {"indexed": 1}
- async def _fakeOl(connectionId, progressCb=None):
+ async def _fakeOl(connectionId, progressCb=None, dataSources=None):
calls["ol"] += 1
return {"indexed": 2}
@@ -142,7 +154,7 @@ def test_bootstrap_job_dispatches_msft_parts(monkeypatch):
async def _run():
return await consumer._bootstrapJobHandler(
{"payload": {"connectionId": "c1", "authority": "msft"}},
- lambda *_: None,
+ lambda *_, **__: None,
)
result = asyncio.run(_run())
@@ -152,13 +164,18 @@ def test_bootstrap_job_dispatches_msft_parts(monkeypatch):
def test_bootstrap_job_dispatches_google_parts(monkeypatch):
+ _stubRagEnabledDs(monkeypatch, [
+ {"id": "ds1", "sourceType": "googleDriveFolder"},
+ {"id": "ds2", "sourceType": "gmailFolder"},
+ ])
+
calls = {"gd": 0, "gm": 0}
- async def _fakeGd(connectionId, progressCb=None):
+ async def _fakeGd(connectionId, progressCb=None, dataSources=None):
calls["gd"] += 1
return {"indexed": 7}
- async def _fakeGm(connectionId, progressCb=None):
+ async def _fakeGm(connectionId, progressCb=None, dataSources=None):
calls["gm"] += 1
return {"indexed": 11}
@@ -180,7 +197,7 @@ def test_bootstrap_job_dispatches_google_parts(monkeypatch):
async def _run():
return await consumer._bootstrapJobHandler(
{"payload": {"connectionId": "c1", "authority": "google"}},
- lambda *_: None,
+ lambda *_, **__: None,
)
result = asyncio.run(_run())
@@ -190,9 +207,13 @@ def test_bootstrap_job_dispatches_google_parts(monkeypatch):
def test_bootstrap_job_dispatches_clickup_part(monkeypatch):
+ _stubRagEnabledDs(monkeypatch, [
+ {"id": "ds1", "sourceType": "clickupList"},
+ ])
+
calls = {"cu": 0}
- async def _fakeCu(connectionId, progressCb=None):
+ async def _fakeCu(connectionId, progressCb=None, dataSources=None):
calls["cu"] += 1
return {"indexed": 4}
@@ -207,7 +228,7 @@ def test_bootstrap_job_dispatches_clickup_part(monkeypatch):
async def _run():
return await consumer._bootstrapJobHandler(
{"payload": {"connectionId": "c1", "authority": "clickup"}},
- lambda *_: None,
+ lambda *_, **__: None,
)
result = asyncio.run(_run())
diff --git a/tests/unit/services/test_ragLimits.py b/tests/unit/services/test_ragLimits.py
new file mode 100644
index 00000000..bb336ed3
--- /dev/null
+++ b/tests/unit/services/test_ragLimits.py
@@ -0,0 +1,79 @@
+"""Unit tests for `_ragLimits` central helpers.
+
+Verifies:
+- defaults are returned as fresh copies (no mutation leakage)
+- getStoredOverrides returns ONLY explicit overrides (walker contract)
+- getRagLimits merges defaults with overrides (API/cost-estimate contract)
+- non-int values in stored settings are dropped, not silently coerced
+"""
+
+from __future__ import annotations
+
+import unittest
+
+from modules.serviceCenter.services.serviceKnowledge import _ragLimits
+
+
+class TestGetDefaults(unittest.TestCase):
+ def test_files_defaults_have_all_keys(self):
+ d = _ragLimits.getDefaults("files")
+ self.assertEqual(set(d.keys()), {"maxItems", "maxBytes", "maxFileSize", "maxDepth"})
+ self.assertEqual(d["maxBytes"], 200 * 1024 * 1024)
+
+ def test_clickup_defaults(self):
+ d = _ragLimits.getDefaults("clickup")
+ self.assertEqual(set(d.keys()), {"maxTasks", "maxWorkspaces", "maxListsPerWorkspace"})
+
+ def test_defaults_are_a_fresh_copy(self):
+ d1 = _ragLimits.getDefaults("files")
+ d1["maxBytes"] = 1
+ d2 = _ragLimits.getDefaults("files")
+ self.assertEqual(d2["maxBytes"], 200 * 1024 * 1024)
+
+ def test_unknown_kind_raises(self):
+ with self.assertRaises(ValueError):
+ _ragLimits.getDefaults("unknown")
+
+
+class TestGetStoredOverrides(unittest.TestCase):
+ def test_no_settings_returns_empty_dict(self):
+ self.assertEqual(_ragLimits.getStoredOverrides({"id": "x", "settings": None}, "files"), {})
+
+ def test_only_explicit_overrides_returned(self):
+ ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": 999}}}
+ self.assertEqual(_ragLimits.getStoredOverrides(ds, "files"), {"maxBytes": 999})
+
+ def test_unknown_keys_dropped(self):
+ ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": 999, "bogus": 1}}}
+ self.assertEqual(_ragLimits.getStoredOverrides(ds, "files"), {"maxBytes": 999})
+
+ def test_non_int_dropped(self):
+ ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": "not-a-number"}}}
+ self.assertEqual(_ragLimits.getStoredOverrides(ds, "files"), {})
+
+ def test_none_or_garbage_settings_safe(self):
+ self.assertEqual(_ragLimits.getStoredOverrides(None, "files"), {})
+ self.assertEqual(_ragLimits.getStoredOverrides({"id": "x", "settings": "garbage"}, "files"), {})
+
+
+class TestGetRagLimits(unittest.TestCase):
+ def test_no_settings_returns_defaults(self):
+ result = _ragLimits.getRagLimits({"id": "x", "settings": None}, "files")
+ self.assertEqual(result, _ragLimits.FILES_LIMITS_DEFAULT)
+
+ def test_partial_override_merges_with_defaults(self):
+ ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": 999}}}
+ result = _ragLimits.getRagLimits(ds, "files")
+ self.assertEqual(result["maxBytes"], 999)
+ self.assertEqual(result["maxItems"], _ragLimits.FILES_LIMITS_DEFAULT["maxItems"])
+
+ def test_caller_can_distinguish_unset_from_set(self):
+ """Walker contract: an unset key MUST NOT appear in `getStoredOverrides`."""
+ ds = {"id": "x", "settings": {"ragLimits": {"maxBytes": 999}}}
+ overrides = _ragLimits.getStoredOverrides(ds, "files")
+ self.assertIn("maxBytes", overrides)
+ self.assertNotIn("maxItems", overrides)
+
+
+if __name__ == "__main__":
+ unittest.main()
From 1ed462ad13742c7bf4b52a5abb7a83395a2ddb4c Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Tue, 19 May 2026 16:48:01 +0200
Subject: [PATCH 3/6] fixes rag and workflow
---
modules/connectors/connectorDbPostgre.py | 16 +-
.../datamodels/datamodelFeatureDataSource.py | 8 +
modules/demoConfigs/investorDemo2026.py | 25 +-
modules/demoConfigs/pwgDemo2026.py | 23 +-
.../workspace/datamodelFeatureWorkspace.py | 9 +-
.../workspace/routeFeatureWorkspace.py | 625 +++-------
modules/routes/routeAdminDemoConfig.py | 14 +-
modules/routes/routeDataConnections.py | 7 +-
modules/routes/routeDataFiles.py | 261 ++++-
modules/routes/routeDataSources.py | 194 +++-
modules/routes/routeRagInventory.py | 271 ++++-
modules/security/rbac.py | 16 +-
.../coreTools/_dataSourceTools.py | 7 +-
.../coreTools/_featureSubAgentTools.py | 6 +-
.../serviceAgent/featureDataProvider.py | 42 +-
.../services/serviceKnowledge/_buildTree.py | 1020 +++++++++++++++++
.../serviceKnowledge/_inheritFlags.py | 503 ++++++--
.../serviceKnowledge/mainServiceKnowledge.py | 2 +-
.../subConnectorIngestConsumer.py | 11 +-
.../serviceKnowledge/subFeatureBootstrap.py | 289 +++++
.../serviceKnowledge/subPolicyResolver.py | 32 -
.../serviceKnowledge/subWalkerHelpers.py | 17 +-
scripts/script_migrate_user_uid.py | 274 +++++
.../test_connectorDbPostgre_failLoud.py | 10 +
tests/unit/services/test_buildTree.py | 359 ++++++
tests/unit/services/test_inheritFlags.py | 525 +++++++--
tests/unit/teamsbot/test_directorPrompts.py | 10 +-
27 files changed, 3728 insertions(+), 848 deletions(-)
create mode 100644 modules/serviceCenter/services/serviceKnowledge/_buildTree.py
create mode 100644 modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py
delete mode 100644 modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py
create mode 100644 scripts/script_migrate_user_uid.py
create mode 100644 tests/unit/services/test_buildTree.py
diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py
index f1a34f70..fa4cba44 100644
--- a/modules/connectors/connectorDbPostgre.py
+++ b/modules/connectors/connectorDbPostgre.py
@@ -172,7 +172,7 @@ def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: s
pass # already a list
elif fieldType == "BOOLEAN":
- record[fieldName] = bool(value) if value is not None else False
+ record[fieldName] = bool(value) if value is not None else None
elif fieldType == "JSONB" and value is not None:
try:
@@ -184,6 +184,18 @@ def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: s
logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})")
+def _stripNulBytesFromStr(value: Any) -> Any:
+ """psycopg2 rejects bound parameters whose Python str contains NUL (0x00).
+
+ Some extracted files (e.g. SQL dumps, mixed binary treated as text) still
+ carry those bytes; PostgreSQL TEXT could store them via other paths, but
+ the client protocol path used here cannot.
+ """
+ if isinstance(value, str) and "\x00" in value:
+ return value.replace("\x00", "")
+ return value
+
+
def _quotePgIdent(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
@@ -983,7 +995,7 @@ class DatabaseConnector:
else:
value = json.dumps(value)
- values.append(value)
+ values.append(_stripNulBytesFromStr(value))
# Build INSERT/UPDATE with quoted identifiers
col_names = ", ".join([f'"{col}"' for col in columns])
diff --git a/modules/datamodels/datamodelFeatureDataSource.py b/modules/datamodels/datamodelFeatureDataSource.py
index f07a8bda..10fd76a7 100644
--- a/modules/datamodels/datamodelFeatureDataSource.py
+++ b/modules/datamodels/datamodelFeatureDataSource.py
@@ -76,6 +76,14 @@ class FeatureDataSource(PowerOnModel):
),
json_schema_extra={"label": "Neutralisieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False},
)
+ ragIndexEnabled: Optional[bool] = Field(
+ default=None,
+ description=(
+ "Three-state RAG-indexing flag with cascade-inherit semantics. "
+ "None = inherit; True/False = explicit. Cascade-reset on parent toggle."
+ ),
+ json_schema_extra={"label": "RAG-Indexierung", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False},
+ )
neutralizeFields: Optional[List[str]] = Field(
default=None,
description="Column names whose values are replaced with placeholders before AI processing",
diff --git a/modules/demoConfigs/investorDemo2026.py b/modules/demoConfigs/investorDemo2026.py
index f8fc678f..d807921d 100644
--- a/modules/demoConfigs/investorDemo2026.py
+++ b/modules/demoConfigs/investorDemo2026.py
@@ -124,6 +124,7 @@ class InvestorDemo2026(_BaseDemoConfig):
from modules.datamodels.datamodelUam import Mandate, UserInDB
from modules.datamodels.datamodelMembership import UserMandate
+ summary["_removedMandateIds"] = []
for mandateDef in [_MANDATE_HAPPYLIFE, _MANDATE_ALPINA]:
try:
existing = db.getRecordset(Mandate, recordFilter={"name": mandateDef["name"]})
@@ -132,28 +133,36 @@ class InvestorDemo2026(_BaseDemoConfig):
self._removeMandateData(db, mid, mandateDef["label"], summary)
db.recordDelete(Mandate, mid)
summary["removed"].append(f"Mandate {mandateDef['label']} ({mid})")
+ summary["_removedMandateIds"].append({"id": mid, "mandateId": mid})
logger.info(f"Removed mandate {mandateDef['label']} ({mid})")
except Exception as e:
summary["errors"].append(f"Remove mandate {mandateDef['label']}: {e}")
+ # SAFETY: NEVER delete the user record. The user may have connections,
+ # chats, workflows, files, and other data across multiple databases.
+ # Only remove the mandate memberships that THIS demo created.
try:
existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]})
for u in existing:
uid = u.get("id")
+ removedMandateIds = {m.get("mandateId") for m in summary.get("_removedMandateIds", [])}
memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid})
for mem in memberships:
- try:
- db.recordDelete(UserMandate, mem.get("id"))
- except Exception:
- pass
- db.recordDelete(UserInDB, uid)
- summary["removed"].append(f"User {_USER['username']} ({uid})")
- logger.info(f"Removed user {_USER['username']} ({uid})")
+ if mem.get("mandateId") in removedMandateIds:
+ try:
+ db.recordDelete(UserMandate, mem.get("id"))
+ except Exception:
+ pass
+ summary["skipped"].append(
+ f"User {_USER['username']} ({uid}) preserved (only demo mandate memberships removed)"
+ )
+ logger.info(f"Preserved user {_USER['username']} ({uid}) - removed demo mandate memberships only")
except Exception as e:
- summary["errors"].append(f"Remove user: {e}")
+ summary["errors"].append(f"Remove user memberships: {e}")
self._removeLanguageSet(db, "es", summary)
+ summary.pop("_removedMandateIds", None)
return summary
# ------------------------------------------------------------------
diff --git a/modules/demoConfigs/pwgDemo2026.py b/modules/demoConfigs/pwgDemo2026.py
index f0dc5e6d..4a6491a3 100644
--- a/modules/demoConfigs/pwgDemo2026.py
+++ b/modules/demoConfigs/pwgDemo2026.py
@@ -121,32 +121,39 @@ class PwgDemo2026(_BaseDemoConfig):
from modules.datamodels.datamodelMembership import UserMandate
from modules.datamodels.datamodelUam import Mandate, UserInDB
+ removedMandateIds = set()
try:
existing = db.getRecordset(Mandate, recordFilter={"name": _MANDATE_PWG["name"]})
for m in existing:
mid = m.get("id")
self._removeMandateData(db, mid, _MANDATE_PWG["label"], summary)
db.recordDelete(Mandate, mid)
+ removedMandateIds.add(mid)
summary["removed"].append(f"Mandate {_MANDATE_PWG['label']} ({mid})")
logger.info(f"Removed mandate {_MANDATE_PWG['label']} ({mid})")
except Exception as e:
summary["errors"].append(f"Remove mandate {_MANDATE_PWG['label']}: {e}")
+ # SAFETY: NEVER delete the user record. The user may have connections,
+ # chats, workflows, files, and other data across multiple databases.
+ # Only remove the mandate memberships that THIS demo created.
try:
existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]})
for u in existing:
uid = u.get("id")
memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid}) or []
for mem in memberships:
- try:
- db.recordDelete(UserMandate, mem.get("id"))
- except Exception:
- pass
- db.recordDelete(UserInDB, uid)
- summary["removed"].append(f"User {_USER['username']} ({uid})")
- logger.info(f"Removed user {_USER['username']} ({uid})")
+ if mem.get("mandateId") in removedMandateIds:
+ try:
+ db.recordDelete(UserMandate, mem.get("id"))
+ except Exception:
+ pass
+ summary["skipped"].append(
+ f"User {_USER['username']} ({uid}) preserved (only demo mandate memberships removed)"
+ )
+ logger.info(f"Preserved user {_USER['username']} ({uid}) - removed demo mandate memberships only")
except Exception as e:
- summary["errors"].append(f"Remove user: {e}")
+ summary["errors"].append(f"Remove user memberships: {e}")
return summary
diff --git a/modules/features/workspace/datamodelFeatureWorkspace.py b/modules/features/workspace/datamodelFeatureWorkspace.py
index 4e32702c..d0ba8815 100644
--- a/modules/features/workspace/datamodelFeatureWorkspace.py
+++ b/modules/features/workspace/datamodelFeatureWorkspace.py
@@ -2,7 +2,7 @@
# All rights reserved.
"""Workspace feature data models — WorkspaceUserSettings."""
-from typing import List, Optional
+from typing import Dict, List, Optional
from pydantic import Field
from modules.datamodels.datamodelBase import PowerOnModel
from modules.shared.i18nRegistry import i18nModel
@@ -52,7 +52,7 @@ class WorkspaceUserSettings(PowerOnModel):
description="Max agent rounds override (None = instance default)",
json_schema_extra={"label": "Max. Agenten-Runden", "frontend_type": "number", "frontend_readonly": False, "frontend_required": False},
)
- requireNeutralization: bool = Field(
+ requireNeutralization: Optional[bool] = Field(
default=False,
description="Default neutralization setting for this user",
json_schema_extra={"label": "Neutralisierung", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False},
@@ -67,3 +67,8 @@ class WorkspaceUserSettings(PowerOnModel):
description="Allowed AI models (empty = all permitted)",
json_schema_extra={"label": "Erlaubte Modelle", "frontend_type": "modelMultiSelect", "frontend_readonly": False, "frontend_required": False},
)
+ uiTreeExpansion: Dict[str, List[str]] = Field(
+ default_factory=dict,
+ description="Per-tab expanded tree-node ids for the UDB / FormGeneratorTree. Key = scope name (e.g. 'sources', 'filesOwn', 'filesShared').",
+ json_schema_extra={"label": "Tree-Expand-Zustand", "frontend_type": "json", "frontend_readonly": True, "frontend_required": False},
+ )
diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py
index 2fa788e8..5c24c113 100644
--- a/modules/features/workspace/routeFeatureWorkspace.py
+++ b/modules/features/workspace/routeFeatureWorkspace.py
@@ -1281,52 +1281,101 @@ async def listWorkspaceDataSources(
try:
from modules.datamodels.datamodelDataSource import DataSource
from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import buildEffectiveByConnection
rootIf = getRootInterface()
recordFilter: dict = {"featureInstanceId": instanceId}
if wsMandateId:
recordFilter["mandateId"] = wsMandateId
dataSources = rootIf.db.getRecordset(DataSource, recordFilter=recordFilter)
- return JSONResponse({"dataSources": dataSources or []})
+ if not dataSources:
+ return JSONResponse({"dataSources": []})
+
+ # Group by connectionId and compute effective values in aggregate mode
+ byConnection: dict = {}
+ for ds in dataSources:
+ connId = ds.get("connectionId") or ""
+ byConnection.setdefault(connId, []).append(ds)
+
+ for connDs in byConnection.values():
+ effNeutralize = buildEffectiveByConnection(connDs, "neutralize", mode="aggregate")
+ effScope = buildEffectiveByConnection(connDs, "scope", mode="aggregate")
+ effRag = buildEffectiveByConnection(connDs, "ragIndexEnabled", mode="aggregate")
+ for ds in connDs:
+ dsId = ds.get("id", "")
+ ds["effectiveNeutralize"] = effNeutralize.get(dsId, False)
+ ds["effectiveScope"] = effScope.get(dsId, "personal")
+ ds["effectiveRagIndexEnabled"] = effRag.get(dsId, False)
+
+ return JSONResponse({"dataSources": dataSources})
except Exception:
return JSONResponse({"dataSources": []})
-@router.get("/{instanceId}/connections")
+class _TreeChildrenRequest(BaseModel):
+ """Request body for the generic tree children endpoint."""
+ parents: List[Optional[str]] = Field(
+ default_factory=list,
+ description="List of parent keys to fetch children for. Use null for top-level.",
+ )
+
+
+@router.post("/{instanceId}/tree/children")
@limiter.limit("300/minute")
-async def listWorkspaceConnections(
+async def getTreeChildren(
request: Request,
instanceId: str = Path(...),
+ body: _TreeChildrenRequest = Body(...),
context: RequestContext = Depends(getRequestContext),
):
- """Return the user's active connections (UserConnections)."""
- _mandateId, _ = _validateInstanceAccess(instanceId, context)
- from modules.serviceCenter import getService
- from modules.serviceCenter.context import ServiceCenterContext
- ctx = ServiceCenterContext(
- user=context.user,
- mandate_id=_mandateId or "",
- feature_instance_id=instanceId,
+ """Generic UDB tree children resolver.
+
+ The UI sends a list of parent keys (or null for top-level). The backend
+ returns children for each requested parent, with all effective flag
+ values pre-computed. The UI builds the visible tree from the resulting
+ flat per-parent map.
+ """
+ _validateInstanceAccess(instanceId, context)
+ from modules.serviceCenter.services.serviceKnowledge._buildTree import getChildrenForParents
+
+ try:
+ nodesByParent = await getChildrenForParents(instanceId, body.parents, context)
+ except Exception as exc:
+ logger.exception("Tree children build failed: %s", exc)
+ raise HTTPException(status_code=500, detail=str(exc))
+ return JSONResponse({"nodesByParent": nodesByParent})
+
+
+class _TreeAttributesRequest(BaseModel):
+ """Request body for the attribute-refresh endpoint."""
+ keys: List[str] = Field(
+ default_factory=list,
+ description="List of node keys to fetch current attributes for.",
)
- chatService = getService("chat", ctx)
- connections = chatService.getUserConnections()
- items = []
- for c in connections or []:
- conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {})
- authority = conn.get("authority")
- if hasattr(authority, "value"):
- authority = authority.value
- status = conn.get("status")
- if hasattr(status, "value"):
- status = status.value
- items.append({
- "id": conn.get("id"),
- "authority": authority,
- "externalUsername": conn.get("externalUsername"),
- "externalEmail": conn.get("externalEmail"),
- "status": status,
- "knowledgeIngestionEnabled": bool(conn.get("knowledgeIngestionEnabled")),
- })
- return JSONResponse({"connections": items})
+
+
+@router.post("/{instanceId}/tree/attributes")
+@limiter.limit("300/minute")
+async def getTreeAttributes(
+ request: Request,
+ instanceId: str = Path(...),
+ body: _TreeAttributesRequest = Body(...),
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Return current effective attribute values (neutralize, scope,
+ ragIndexEnabled) for a list of node keys. Used after a toggle action
+ to refresh only the visible nodes without reloading tree structure."""
+ _validateInstanceAccess(instanceId, context)
+ from modules.serviceCenter.services.serviceKnowledge._buildTree import getAttributesForKeys
+
+ if len(body.keys) > 500:
+ raise HTTPException(status_code=400, detail="Max 500 keys per request")
+
+ try:
+ attrs = await getAttributesForKeys(instanceId, body.keys, context)
+ except Exception as exc:
+ logger.exception("Tree attributes failed: %s", exc)
+ raise HTTPException(status_code=500, detail=str(exc))
+ return JSONResponse({"attributes": attrs})
class CreateDataSourceRequest(BaseModel):
@@ -1391,303 +1440,6 @@ async def deleteWorkspaceDataSource(
# ---- Feature Connections & Feature Data Sources ----
-@router.get("/{instanceId}/feature-connections")
-@limiter.limit("120/minute")
-async def listFeatureConnections(
- request: Request,
- instanceId: str = Path(...),
- context: RequestContext = Depends(getRequestContext),
-):
- """List feature instances the user has access to, scoped to the workspace mandate."""
- wsMandateId, _ = _validateInstanceAccess(instanceId, context)
- from modules.interfaces.interfaceDbApp import getRootInterface
- from modules.security.rbacCatalog import getCatalogService
- from modules.datamodels.datamodelUam import Mandate
-
- rootIf = getRootInterface()
- userId = str(context.user.id)
-
- catalog = getCatalogService()
- featureCodesWithData = catalog.getFeaturesWithDataObjects()
-
- userMandates = rootIf.getUserMandates(userId)
- if not userMandates:
- return JSONResponse({"featureConnectionsByMandate": []})
-
- allowedMandateIds = {um.mandateId for um in userMandates}
- if wsMandateId and wsMandateId in allowedMandateIds:
- allowedMandateIds = {wsMandateId}
-
- mandateLabels: dict = {}
- for um in userMandates:
- if um.mandateId not in allowedMandateIds:
- continue
- try:
- rows = rootIf.db.getRecordset(Mandate, recordFilter={"id": um.mandateId})
- if rows:
- m = rows[0]
- mandateLabels[um.mandateId] = m.get("label") or m.get("name") or um.mandateId
- except Exception:
- mandateLabels[um.mandateId] = um.mandateId
-
- byMandate: dict = {}
- seenIds: set = set()
- for um in userMandates:
- if um.mandateId not in allowedMandateIds:
- continue
- allInstances = rootIf.getFeatureInstancesByMandate(um.mandateId)
- for inst in allInstances:
- if inst.id in seenIds:
- continue
- seenIds.add(inst.id)
- if not inst.enabled:
- continue
- if inst.featureCode not in featureCodesWithData:
- continue
- featureAccess = rootIf.getFeatureAccess(userId, inst.id)
- if not featureAccess or not featureAccess.enabled:
- continue
-
- featureDef = catalog.getFeatureDefinition(inst.featureCode) or {}
- dataObjects = catalog.getDataObjects(inst.featureCode)
- label = inst.label or inst.featureCode
- mid = inst.mandateId
- connItem = {
- "featureInstanceId": inst.id,
- "featureCode": inst.featureCode,
- "mandateId": mid,
- "label": label,
- "icon": featureDef.get("icon", "mdi-database"),
- "tableCount": len(dataObjects),
- }
- if mid not in byMandate:
- byMandate[mid] = []
- byMandate[mid].append(connItem)
-
- def _sortKeyLabel(x: dict) -> str:
- return (x.get("label") or "").lower()
-
- groups = []
- for mid in sorted(byMandate.keys(), key=lambda m: (mandateLabels.get(m, m) or "").lower()):
- conns = sorted(byMandate[mid], key=_sortKeyLabel)
- groups.append({
- "mandateId": mid,
- "mandateLabel": mandateLabels.get(mid, mid),
- "featureConnections": conns,
- })
-
- return JSONResponse({"featureConnectionsByMandate": groups})
-
-
-@router.get("/{instanceId}/feature-connections/{fiId}/tables")
-@limiter.limit("120/minute")
-async def listFeatureConnectionTables(
- request: Request,
- instanceId: str = Path(...),
- fiId: str = Path(..., description="Feature instance ID"),
- context: RequestContext = Depends(getRequestContext),
-):
- """List data tables (DATA_OBJECTS) for a feature instance, filtered by RBAC."""
- wsMandateId, _ = _validateInstanceAccess(instanceId, context)
- from modules.interfaces.interfaceDbApp import getRootInterface
- from modules.security.rbacCatalog import getCatalogService
-
- rootIf = getRootInterface()
- inst = rootIf.getFeatureInstance(fiId)
- if not inst:
- raise HTTPException(status_code=404, detail=routeApiMsg("Feature instance not found"))
-
- mandateId = str(inst.mandateId) if inst.mandateId else None
- if wsMandateId and mandateId and mandateId != wsMandateId:
- raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate"))
- catalog = getCatalogService()
-
- try:
- from modules.security.rbac import RbacClass
- from modules.security.rootAccess import getRootDbAppConnector
- dbApp = getRootDbAppConnector()
- rbac = RbacClass(dbApp, dbApp=dbApp)
- accessible = catalog.getAccessibleDataObjects(
- featureCode=inst.featureCode,
- rbacInstance=rbac,
- user=context.user,
- mandateId=mandateId or "",
- featureInstanceId=fiId,
- )
- except Exception:
- accessible = catalog.getDataObjects(inst.featureCode)
-
- accessibleKeys = {obj.get("objectKey", "") for obj in accessible}
- referencedGroups = set()
- for obj in accessible:
- meta = obj.get("meta", {})
- if meta.get("wildcard") or meta.get("isGroup"):
- continue
- if meta.get("group"):
- referencedGroups.add(meta["group"])
-
- tables = []
- for obj in catalog.getDataObjects(inst.featureCode):
- meta = obj.get("meta", {})
- if meta.get("wildcard"):
- continue
- objectKey = obj.get("objectKey", "")
- if meta.get("isGroup"):
- # Groups are metadata-only; include if at least one child is accessible
- # (regardless of whether the group itself was RBAC-granted).
- if objectKey not in referencedGroups:
- continue
- else:
- if objectKey not in accessibleKeys:
- continue
- node = {
- "objectKey": objectKey,
- "tableName": meta.get("table", ""),
- "label": resolveText(obj.get("label", "")),
- "fields": meta.get("fields", []),
- "isParent": bool(meta.get("isParent", False)),
- "parentTable": meta.get("parentTable") or None,
- "parentKey": meta.get("parentKey") or None,
- "displayFields": meta.get("displayFields", []),
- "isGroup": bool(meta.get("isGroup", False)),
- "group": meta.get("group") or None,
- }
- tables.append(node)
-
- return JSONResponse({"tables": tables})
-
-
-@router.get("/{instanceId}/feature-connections/{fiId}/parent-objects/{tableName}")
-@limiter.limit("120/minute")
-async def listParentObjects(
- request: Request,
- instanceId: str = Path(...),
- fiId: str = Path(..., description="Feature instance ID"),
- tableName: str = Path(..., description="Parent table name from DATA_OBJECTS"),
- parentKey: Optional[str] = Query(None, description="Optional FK column name to filter by ancestor record (nested parent rendering)"),
- parentValue: Optional[str] = Query(None, description="Optional FK value matching parentKey to filter children of a specific ancestor record"),
- context: RequestContext = Depends(getRequestContext),
-):
- """List records from a parent table so the user can pick a specific record to scope data.
-
- When parentKey + parentValue are provided, results are additionally filtered by that FK,
- enabling nested record hierarchies (e.g. Sessions OF Context X).
- """
- wsMandateId, _ = _validateInstanceAccess(instanceId, context)
- from modules.interfaces.interfaceDbApp import getRootInterface
- from modules.security.rbacCatalog import getCatalogService
-
- rootIf = getRootInterface()
- inst = rootIf.getFeatureInstance(fiId)
- if not inst:
- raise HTTPException(status_code=404, detail=routeApiMsg("Feature instance not found"))
-
- featureCode = inst.featureCode
- mandateId = str(inst.mandateId) if inst.mandateId else ""
- if wsMandateId and mandateId and mandateId != wsMandateId:
- raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate"))
- catalog = getCatalogService()
-
- parentObj = None
- for obj in catalog.getDataObjects(featureCode):
- meta = obj.get("meta", {})
- if meta.get("table") == tableName and meta.get("isParent"):
- parentObj = obj
- break
- if not parentObj:
- raise HTTPException(status_code=400, detail=f"Table '{tableName}' is not a registered parent table")
-
- displayFields = parentObj["meta"].get("displayFields", [])
- selectCols = ', '.join(f'"{f}"' for f in (["id"] + displayFields)) if displayFields else "*"
-
- from modules.connectors.connectorDbPostgre import DatabaseConnector
- from modules.shared.configuration import APP_CONFIG
- featureDbName = f"poweron_{featureCode.lower()}"
- featureDbConn = None
- try:
- featureDbConn = DatabaseConnector(
- dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
- dbDatabase=featureDbName,
- dbUser=APP_CONFIG.get("DB_USER"),
- dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
- dbPort=int(APP_CONFIG.get("DB_PORT", 5432)),
- userId=str(context.user.id),
- )
- conn = featureDbConn.connection
- with conn.cursor() as cur:
- cur.execute(
- "SELECT column_name FROM information_schema.columns "
- "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
- "AND column_name IN ('featureInstanceId', 'instanceId')",
- [tableName],
- )
- instanceCols = [row["column_name"] for row in cur.fetchall()]
- instanceCol = "featureInstanceId" if "featureInstanceId" in instanceCols else "instanceId"
-
- cur.execute(
- "SELECT column_name FROM information_schema.columns "
- "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
- "AND column_name = 'userId'",
- [tableName],
- )
- hasUserId = cur.rowcount > 0
-
- sql = (
- f'SELECT {selectCols} FROM "{tableName}" '
- f'WHERE "{instanceCol}" = %s'
- )
- params = [fiId]
- if mandateId:
- sql += ' AND "mandateId" = %s'
- params.append(mandateId)
- if hasUserId:
- sql += ' AND "userId" = %s'
- params.append(str(context.user.id))
-
- if parentKey and parentValue:
- cur.execute(
- "SELECT 1 FROM information_schema.columns "
- "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
- "AND column_name = %s",
- [tableName, parentKey],
- )
- if cur.rowcount > 0:
- sql += f' AND "{parentKey}" = %s'
- params.append(parentValue)
- else:
- logger.warning(
- f"listParentObjects({tableName}): ignoring parentKey '{parentKey}' (column does not exist)"
- )
-
- sql += ' ORDER BY "id" DESC LIMIT 100'
- cur.execute(sql, params)
- rows = []
- for row in cur.fetchall():
- r = dict(row)
- for k, v in r.items():
- if hasattr(v, "isoformat"):
- r[k] = v.isoformat()
- elif isinstance(v, (bytes, bytearray)):
- r[k] = f""
- displayParts = [str(r.get(f, "")) for f in displayFields if r.get(f) is not None]
- rows.append({
- "id": r.get("id", ""),
- "displayLabel": " | ".join(displayParts) if displayParts else r.get("id", ""),
- "fields": {f: r.get(f) for f in displayFields},
- })
- except Exception as e:
- logger.error(f"listParentObjects({tableName}) failed: {e}", exc_info=True)
- raise HTTPException(status_code=500, detail=f"Failed to list parent objects: {e}")
- finally:
- if featureDbConn:
- try:
- featureDbConn.close()
- except Exception:
- pass
-
- return JSONResponse({"parentObjects": rows})
-
-
class CreateFeatureDataSourceRequest(BaseModel):
"""Request body for adding a feature table as data source."""
featureInstanceId: str = Field(description="Feature instance ID")
@@ -1706,16 +1458,35 @@ async def createFeatureDataSource(
body: CreateFeatureDataSourceRequest = Body(...),
context: RequestContext = Depends(getRequestContext),
):
- """Create a FeatureDataSource for this workspace instance."""
+ """Create a FeatureDataSource for this workspace instance.
+
+ The FDS lives under the WORKSPACE's mandate (not the feature's): that
+ matches how the tree (`allFds = recordset where workspaceInstanceId =
+ instanceId`) and the PATCH endpoints scope these records — by workspace,
+ not by feature mandate. The user can legitimately reference a feature
+ from another mandate they have access to (via the UDB mandate-group
+ nodes), and a hard cross-mandate block here would silently 403 those
+ toggles. Access to the referenced feature is verified by the user's
+ `FeatureAccess` and the existing tree-children RBAC, which run before
+ the user can ever click on this node.
+ """
wsMandateId, _ = _validateInstanceAccess(instanceId, context)
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
rootIf = getRootInterface()
- inst = rootIf.getFeatureInstance(body.featureInstanceId)
- mandateId = str(inst.mandateId) if inst else (str(context.mandateId) if context.mandateId else "")
- if wsMandateId and mandateId and mandateId != wsMandateId:
- raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate"))
+ if not rootIf.getFeatureAccess(str(context.user.id), body.featureInstanceId):
+ raise HTTPException(status_code=403, detail=routeApiMsg("Access denied to this feature instance"))
+
+ existing = rootIf.db.getRecordset(FeatureDataSource, recordFilter={
+ "workspaceInstanceId": instanceId,
+ "featureInstanceId": body.featureInstanceId,
+ "tableName": body.tableName,
+ }) or []
+ targetFilter = body.recordFilter or None
+ for rec in existing:
+ if (rec.get("recordFilter") or None) == targetFilter:
+ return JSONResponse(rec)
fds = FeatureDataSource(
featureInstanceId=body.featureInstanceId,
@@ -1723,7 +1494,7 @@ async def createFeatureDataSource(
tableName=body.tableName,
objectKey=body.objectKey,
label=body.label,
- mandateId=mandateId,
+ mandateId=wsMandateId or "",
userId=str(context.user.id),
workspaceInstanceId=instanceId,
recordFilter=body.recordFilter,
@@ -1743,13 +1514,26 @@ async def listFeatureDataSources(
wsMandateId, _ = _validateInstanceAccess(instanceId, context)
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import buildEffectiveByWorkspaceFds
rootIf = getRootInterface()
recordFilter: dict = {"workspaceInstanceId": instanceId}
if wsMandateId:
recordFilter["mandateId"] = wsMandateId
records = rootIf.db.getRecordset(FeatureDataSource, recordFilter=recordFilter)
- return JSONResponse({"featureDataSources": records or []})
+ if not records:
+ return JSONResponse({"featureDataSources": []})
+
+ effNeutralize = buildEffectiveByWorkspaceFds(records, "neutralize", mode="aggregate")
+ effScope = buildEffectiveByWorkspaceFds(records, "scope", mode="aggregate")
+ effRag = buildEffectiveByWorkspaceFds(records, "ragIndexEnabled", mode="aggregate")
+ for fds in records:
+ fdsId = fds.get("id", "")
+ fds["effectiveNeutralize"] = effNeutralize.get(fdsId, False)
+ fds["effectiveScope"] = effScope.get(fdsId, "personal")
+ fds["effectiveRagIndexEnabled"] = effRag.get(fdsId, False)
+
+ return JSONResponse({"featureDataSources": records})
@router.delete("/{instanceId}/feature-datasources/{featureDataSourceId}")
@@ -1770,112 +1554,6 @@ async def deleteFeatureDataSource(
return JSONResponse({"success": True})
-@router.get("/{instanceId}/connections/{connectionId}/services")
-@limiter.limit("120/minute")
-async def listConnectionServices(
- request: Request,
- instanceId: str = Path(...),
- connectionId: str = Path(...),
- context: RequestContext = Depends(getRequestContext),
-):
- """Return the available services for a specific UserConnection."""
- _mandateId, _ = _validateInstanceAccess(instanceId, context)
- try:
- from modules.connectors.connectorResolver import ConnectorResolver
- from modules.serviceCenter import getService as getSvc
- from modules.serviceCenter.context import ServiceCenterContext
- ctx = ServiceCenterContext(
- user=context.user,
- mandate_id=_mandateId or "",
- feature_instance_id=instanceId,
- )
- chatService = getSvc("chat", ctx)
- securityService = getSvc("security", ctx)
- dbInterface = _buildResolverDbInterface(chatService)
- resolver = ConnectorResolver(securityService, dbInterface)
- provider = await resolver.resolve(connectionId)
- services = provider.getAvailableServices()
- _serviceLabels = {
- "sharepoint": "SharePoint",
- "outlook": "Outlook",
- "teams": "Teams",
- "onedrive": "OneDrive",
- "drive": "Google Drive",
- "gmail": "Gmail",
- "files": "Files (FTP)",
- "kdrive": "kDrive",
- "calendar": "Calendar",
- "contact": "Contacts",
- }
- _serviceIcons = {
- "sharepoint": "sharepoint",
- "outlook": "mail",
- "teams": "chat",
- "onedrive": "cloud",
- "drive": "cloud",
- "gmail": "mail",
- "files": "folder",
- "kdrive": "cloud",
- "calendar": "calendar",
- "contact": "contact",
- }
- items = [
- {
- "service": s,
- "label": _serviceLabels.get(s, s),
- "icon": _serviceIcons.get(s, "folder"),
- }
- for s in services
- ]
- return JSONResponse({"services": items})
- except Exception as e:
- logger.error(f"Error listing services for connection {connectionId}: {e}")
- return JSONResponse({"services": [], "error": str(e)}, status_code=400)
-
-
-@router.get("/{instanceId}/connections/{connectionId}/browse")
-@limiter.limit("300/minute")
-async def browseConnectionService(
- request: Request,
- instanceId: str = Path(...),
- connectionId: str = Path(...),
- service: str = Query(..., description="Service name (e.g. sharepoint, onedrive, outlook)"),
- path: str = Query("/", description="Path within the service to browse"),
- context: RequestContext = Depends(getRequestContext),
-):
- """Browse folders/items within a connection's service at a given path."""
- _mandateId, _ = _validateInstanceAccess(instanceId, context)
- try:
- from modules.connectors.connectorResolver import ConnectorResolver
- from modules.serviceCenter import getService as getSvc
- from modules.serviceCenter.context import ServiceCenterContext
- ctx = ServiceCenterContext(
- user=context.user,
- mandate_id=_mandateId or "",
- feature_instance_id=instanceId,
- )
- chatService = getSvc("chat", ctx)
- securityService = getSvc("security", ctx)
- dbInterface = _buildResolverDbInterface(chatService)
- resolver = ConnectorResolver(securityService, dbInterface)
- adapter = await resolver.resolveService(connectionId, service)
- entries = await adapter.browse(path, filter=None)
- items = []
- for entry in (entries or []):
- items.append({
- "name": entry.name,
- "path": entry.path,
- "isFolder": entry.isFolder,
- "size": entry.size,
- "mimeType": entry.mimeType,
- "metadata": entry.metadata if hasattr(entry, "metadata") else {},
- })
- return JSONResponse({"items": items, "path": path, "service": service})
- except Exception as e:
- logger.error(f"Error browsing {service} for connection {connectionId} at '{path}': {e}")
- return JSONResponse({"items": [], "error": str(e)}, status_code=400)
-
-
# ---------------------------------------------------------------------------
# Voice endpoints
# ---------------------------------------------------------------------------
@@ -2191,6 +1869,71 @@ async def putWorkspaceUserSettings(
})
+# =========================================================================
+# Per-user UI state: tree expand/collapse (UDB + FilesTab)
+# Persisted on WorkspaceUserSettings.uiTreeExpansion as a {scope: [ids]} map.
+# Each FE tab uses its own scope key so collapse-state for one tab doesn't
+# bleed into another.
+
+@router.get("/{instanceId}/ui-tree-expansion/{scope}")
+@limiter.limit("300/minute")
+async def getUiTreeExpansion(
+ request: Request,
+ instanceId: str = Path(...),
+ scope: str = Path(..., description="UI scope key, e.g. 'sources', 'filesOwn', 'filesShared'"),
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Return the expanded tree-node ids for the current user + scope.
+
+ Returns `null` when the user has never persisted a state for this scope
+ (lets the FE fall back to backend `defaultExpanded` hints). Returns `[]`
+ when the user actively collapsed everything.
+ """
+ _validateInstanceAccess(instanceId, context)
+ wsInterface = _getWorkspaceInterface(context, instanceId)
+ settings = wsInterface.getWorkspaceUserSettings(str(context.user.id))
+ expansion = (settings.uiTreeExpansion if settings else {}) or {}
+ if scope not in expansion:
+ return JSONResponse({"expandedNodes": None})
+ return JSONResponse({"expandedNodes": list(expansion.get(scope) or [])})
+
+
+@router.put("/{instanceId}/ui-tree-expansion/{scope}")
+@limiter.limit("300/minute")
+async def putUiTreeExpansion(
+ request: Request,
+ instanceId: str = Path(...),
+ scope: str = Path(...),
+ body: dict = Body(...),
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Replace the expanded-node list for one scope.
+
+ Body: `{"expandedNodes": List[str]}`. Empty list = explicit collapse-all.
+ """
+ _validateInstanceAccess(instanceId, context)
+ wsInterface = _getWorkspaceInterface(context, instanceId)
+ userId = str(context.user.id)
+ nodes = body.get("expandedNodes")
+ if not isinstance(nodes, list):
+ raise HTTPException(status_code=400, detail=routeApiMsg("expandedNodes must be a list"))
+ cleaned = [str(n) for n in nodes if isinstance(n, (str, int))]
+
+ existing = wsInterface.getWorkspaceUserSettings(userId)
+ existingMap: Dict[str, List[str]] = (existing.uiTreeExpansion if existing else {}) or {}
+ existingMap = dict(existingMap)
+ existingMap[scope] = cleaned
+
+ data = {
+ "userId": userId,
+ "mandateId": str(context.mandateId) if context.mandateId else "",
+ "featureInstanceId": instanceId,
+ "uiTreeExpansion": existingMap,
+ }
+ wsInterface.saveWorkspaceUserSettings(data)
+ return JSONResponse({"expandedNodes": cleaned})
+
+
# =========================================================================
# RAG / Knowledge — anonymised instance statistics (presentation / KPIs)
diff --git a/modules/routes/routeAdminDemoConfig.py b/modules/routes/routeAdminDemoConfig.py
index db37e775..0673c299 100644
--- a/modules/routes/routeAdminDemoConfig.py
+++ b/modules/routes/routeAdminDemoConfig.py
@@ -68,9 +68,19 @@ def removeDemoConfig(
request: Request,
currentUser: User = Depends(requirePlatformAdmin),
) -> dict:
- """Remove all data created by a demo configuration."""
+ """Remove all data created by a demo configuration.
+
+ Requires X-Confirm-Destructive: true header as safety guard.
+ """
from modules.demoConfigs import getDemoConfigByCode
+ confirmHeader = request.headers.get("X-Confirm-Destructive", "").lower()
+ if confirmHeader != "true":
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Destructive operation requires header X-Confirm-Destructive: true",
+ )
+
config = getDemoConfigByCode(code)
if not config:
raise HTTPException(
@@ -79,7 +89,7 @@ def removeDemoConfig(
)
db = getRootDbAppConnector()
- logger.info(f"Removing demo config '{code}' (user: {currentUser.username})")
+ logger.info(f"Removing demo config '{code}' (user: {currentUser.username}, confirmed)")
summary = config.remove(db)
logger.info(f"Demo config '{code}' removed: {summary}")
diff --git a/modules/routes/routeDataConnections.py b/modules/routes/routeDataConnections.py
index e2b08461..2bc48042 100644
--- a/modules/routes/routeDataConnections.py
+++ b/modules/routes/routeDataConnections.py
@@ -778,7 +778,12 @@ async def _updateKnowledgeConsent(
cancelled = cancelJobsByConnection(connectionId)
else:
from modules.datamodels.datamodelDataSource import DataSource
- dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId, "ragIndexEnabled": True})
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
+ allConnDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
+ dataSources = [
+ ds for ds in (allConnDs or [])
+ if getEffectiveFlag(ds, "ragIndexEnabled", allConnDs, mode="walk") is True
+ ]
if dataSources:
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
authority = connection.authority.value if hasattr(connection.authority, "value") else str(connection.authority or "")
diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py
index b22dacae..4bcbcf8f 100644
--- a/modules/routes/routeDataFiles.py
+++ b/modules/routes/routeDataFiles.py
@@ -211,7 +211,7 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user, *, man
from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob
- await knowledgeService.requestIngestion(
+ handle = await knowledgeService.requestIngestion(
IngestionJob(
sourceKind="file",
sourceId=fileId,
@@ -229,7 +229,10 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user, *, man
# Re-acquire interface after await to avoid stale user context from the singleton
mgmtInterface = interfaceDbManagement.getInterface(user)
mgmtInterface.updateFile(fileId, {"status": "active"})
- logger.info(f"Auto-index complete for file {fileId} ({fileName})")
+ if handle.status == "failed":
+ logger.warning(f"Auto-index ingestion failed for file {fileId} ({fileName}): {handle.error}")
+ else:
+ logger.info(f"Auto-index complete for file {fileId} ({fileName})")
except Exception as e:
logger.error(f"Auto-index failed for file {fileId}: {e}", exc_info=True)
@@ -256,6 +259,24 @@ router = APIRouter(
)
+def _getInterfaceForOwnedItem(currentUser: User, context, itemId: str, modelClass) -> Any:
+ """Create a management interface scoped to the item's own context.
+ Looks up the item by ID (unscoped) to resolve its mandateId/featureInstanceId,
+ then creates the interface with THAT context. This ensures toggle operations
+ work regardless of which page the user is on."""
+ unscoped = interfaceDbManagement.getInterface(currentUser)
+ record = unscoped.db.getRecord(modelClass, itemId)
+ if not record:
+ raise interfaceDbManagement.FileNotFoundError(f"Item {itemId} not found")
+ itemMandateId = record.get("mandateId") if isinstance(record, dict) else getattr(record, "mandateId", None)
+ itemInstanceId = record.get("featureInstanceId") if isinstance(record, dict) else getattr(record, "featureInstanceId", None)
+ return interfaceDbManagement.getInterface(
+ currentUser,
+ mandateId=str(itemMandateId) if itemMandateId else None,
+ featureInstanceId=str(itemInstanceId) if itemInstanceId else None,
+ )
+
+
@router.get("/folders/tree")
@limiter.limit("120/minute")
def get_folder_tree(
@@ -272,10 +293,12 @@ def get_folder_tree(
)
o = (owner or "me").strip().lower()
if o == "me":
- return managementInterface.getOwnFolderTree()
- if o == "shared":
- return managementInterface.getSharedFolderTree()
- raise HTTPException(status_code=400, detail="owner must be 'me' or 'shared'")
+ folders = managementInterface.getOwnFolderTree()
+ elif o == "shared":
+ folders = managementInterface.getSharedFolderTree()
+ else:
+ raise HTTPException(status_code=400, detail="owner must be 'me' or 'shared'")
+ return folders
except HTTPException:
raise
except Exception as e:
@@ -283,6 +306,185 @@ def get_folder_tree(
raise HTTPException(status_code=500, detail=str(e))
+@router.post("/attributes")
+@limiter.limit("120/minute")
+def getAttributesForIds(
+ request: Request,
+ body: Dict[str, Any] = Body(...),
+ currentUser: User = Depends(getCurrentUser),
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Return current attribute values (neutralize, scope, ragIndexEnabled) for
+ a list of node IDs. For folder IDs, computes 'mixed' by checking direct
+ children. The frontend sends this after every toggle to refresh visible
+ nodes without reloading the tree structure."""
+ ids = body.get("ids", [])
+ if not isinstance(ids, list) or len(ids) == 0:
+ return {}
+ if len(ids) > 500:
+ raise HTTPException(status_code=400, detail="Max 500 IDs per request")
+
+ try:
+ managementInterface = interfaceDbManagement.getInterface(
+ currentUser,
+ mandateId=str(context.mandateId) if context.mandateId else None,
+ featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
+ )
+ db = managementInterface.db
+ userId = str(currentUser.id)
+
+ allFolders = db.getRecordset(FileFolder, recordFilter={"sysCreatedBy": userId}) or []
+ allFiles = db.getRecordset(FileItem, recordFilter={"sysCreatedBy": userId}) or []
+
+ folderById = {f["id"]: f for f in allFolders}
+ fileById = {f["id"]: f for f in allFiles}
+
+ logger.info(
+ "getAttributesForIds: %d ids requested, %d folders found, %d files found",
+ len(ids), len(allFolders), len(allFiles),
+ )
+
+ result: Dict[str, Dict[str, Any]] = {}
+
+ for nodeId in ids:
+ if nodeId.startswith("__filesRoot:"):
+ attrs = _computeSyntheticRootAttrs(allFolders, allFiles)
+ result[nodeId] = attrs
+ elif nodeId in folderById:
+ folder = folderById[nodeId]
+ attrs = _computeFolderAttrs(folder, allFolders, allFiles)
+ result[nodeId] = attrs
+ elif nodeId in fileById:
+ f = fileById[nodeId]
+ result[nodeId] = {
+ "neutralize": bool(f.get("neutralize", False)),
+ "scope": f.get("scope", "personal"),
+ }
+ else:
+ logger.debug("getAttributesForIds: unknown id=%s", nodeId)
+
+ logger.info("getAttributesForIds: returning %d entries", len(result))
+ return result
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"getAttributesForIds error: {e}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+def _computeFolderAttrs(
+ folder: Dict[str, Any],
+ allFolders: List[Dict[str, Any]],
+ allFiles: List[Dict[str, Any]],
+) -> Dict[str, Any]:
+ """Compute attributes for a folder. Recursively checks the entire subtree:
+ if ANY descendant at any depth has a different value, the folder shows 'mixed'.
+ This propagates up through all ancestor levels."""
+ fid = folder["id"]
+ neutralizeResult = _effectiveNeutralize(fid, allFolders, allFiles)
+ scopeResult = _effectiveScope(fid, allFolders, allFiles)
+ return {"neutralize": neutralizeResult, "scope": scopeResult}
+
+
+def _effectiveNeutralize(
+ folderId: str,
+ allFolders: List[Dict[str, Any]],
+ allFiles: List[Dict[str, Any]],
+) -> Any:
+ """Recursively compute effective neutralize for a folder.
+ Returns 'mixed' if any descendants diverge, otherwise the folder's own value."""
+ childFolders = [f for f in allFolders if f.get("parentId") == folderId]
+ childFiles = [f for f in allFiles if f.get("folderId") == folderId]
+
+ if not childFolders and not childFiles:
+ folder = next((f for f in allFolders if f["id"] == folderId), None)
+ return bool(folder.get("neutralize", False)) if folder else False
+
+ childVals = set()
+ for cf in childFolders:
+ effective = _effectiveNeutralize(cf["id"], allFolders, allFiles)
+ if effective == "mixed":
+ return "mixed"
+ childVals.add(effective)
+ for cf in childFiles:
+ childVals.add(bool(cf.get("neutralize", False)))
+
+ if len(childVals) > 1:
+ return "mixed"
+ if not childVals:
+ folder = next((f for f in allFolders if f["id"] == folderId), None)
+ return bool(folder.get("neutralize", False)) if folder else False
+ return childVals.pop()
+
+
+def _effectiveScope(
+ folderId: str,
+ allFolders: List[Dict[str, Any]],
+ allFiles: List[Dict[str, Any]],
+) -> Any:
+ """Recursively compute effective scope for a folder.
+ Returns 'mixed' if any descendants diverge, otherwise the folder's own value."""
+ childFolders = [f for f in allFolders if f.get("parentId") == folderId]
+ childFiles = [f for f in allFiles if f.get("folderId") == folderId]
+
+ if not childFolders and not childFiles:
+ folder = next((f for f in allFolders if f["id"] == folderId), None)
+ return folder.get("scope", "personal") if folder else "personal"
+
+ childVals = set()
+ for cf in childFolders:
+ effective = _effectiveScope(cf["id"], allFolders, allFiles)
+ if effective == "mixed":
+ return "mixed"
+ childVals.add(effective)
+ for cf in childFiles:
+ childVals.add(cf.get("scope", "personal"))
+
+ if len(childVals) > 1:
+ return "mixed"
+ if not childVals:
+ folder = next((f for f in allFolders if f["id"] == folderId), None)
+ return folder.get("scope", "personal") if folder else "personal"
+ return childVals.pop()
+
+
+def _computeSyntheticRootAttrs(
+ allFolders: List[Dict[str, Any]],
+ allFiles: List[Dict[str, Any]],
+) -> Dict[str, Any]:
+ """Compute attributes for the synthetic root by recursively checking the
+ entire tree. If ANY item at any depth diverges, root shows 'mixed'."""
+ topFolders = [f for f in allFolders if not f.get("parentId")]
+ topFiles = [f for f in allFiles if not f.get("folderId")]
+
+ neutralizeVals = set()
+ scopeVals = set()
+ for cf in topFolders:
+ nEff = _effectiveNeutralize(cf["id"], allFolders, allFiles)
+ if nEff == "mixed":
+ neutralizeVals.add(True)
+ neutralizeVals.add(False)
+ else:
+ neutralizeVals.add(nEff)
+ sEff = _effectiveScope(cf["id"], allFolders, allFiles)
+ if sEff == "mixed":
+ scopeVals.add("__mixed_a__")
+ scopeVals.add("__mixed_b__")
+ else:
+ scopeVals.add(sEff)
+ for cf in topFiles:
+ neutralizeVals.add(bool(cf.get("neutralize", False)))
+ scopeVals.add(cf.get("scope", "personal"))
+
+ if not neutralizeVals and not scopeVals:
+ return {"neutralize": False, "scope": "personal"}
+
+ return {
+ "neutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False),
+ "scope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"),
+ }
+
+
@router.post("/folders", status_code=status.HTTP_201_CREATED)
@limiter.limit("30/minute")
def create_folder(
@@ -353,7 +555,12 @@ def move_folder(
context: RequestContext = Depends(getRequestContext),
):
try:
+ # FE may send `parentId` or `targetParentId`. Accept both so the
+ # FormGeneratorTree generic `provider.moveNodes(targetParentId)` API
+ # remains consistent with the file-move (PUT /api/files/{id}) shape.
newParentId = body.get("parentId")
+ if newParentId is None:
+ newParentId = body.get("targetParentId")
managementInterface = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
@@ -414,11 +621,7 @@ def patch_folder_scope(
if not scope:
raise HTTPException(status_code=400, detail="scope is required")
cascadeToFiles = body.get("cascadeChildren", body.get("cascadeToFiles", False))
- managementInterface = interfaceDbManagement.getInterface(
- currentUser,
- mandateId=str(context.mandateId) if context.mandateId else None,
- featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
- )
+ managementInterface = _getInterfaceForOwnedItem(currentUser, context, folderId, FileFolder)
return managementInterface.patchFolderScope(folderId, scope, cascadeToFiles)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -446,11 +649,7 @@ def patch_folder_neutralize(
neutralize = body.get("neutralize")
if neutralize is None:
raise HTTPException(status_code=400, detail="neutralize is required")
- managementInterface = interfaceDbManagement.getInterface(
- currentUser,
- mandateId=str(context.mandateId) if context.mandateId else None,
- featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
- )
+ managementInterface = _getInterfaceForOwnedItem(currentUser, context, folderId, FileFolder)
return managementInterface.patchFolderNeutralize(folderId, bool(neutralize))
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@@ -1031,11 +1230,7 @@ def updateFileScope(
if scope == "global" and not context.isSysAdmin:
raise HTTPException(status_code=403, detail=routeApiMsg("Only sysadmins can set global scope"))
- managementInterface = interfaceDbManagement.getInterface(
- context.user,
- mandateId=str(context.mandateId) if context.mandateId else None,
- featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
- )
+ managementInterface = _getInterfaceForOwnedItem(context.user, context, fileId, FileItem)
managementInterface.updateFile(fileId, {"scope": scope})
@@ -1093,11 +1288,7 @@ def updateFileNeutralize(
fails the file simply has no index — no un-neutralized data can leak.
"""
try:
- managementInterface = interfaceDbManagement.getInterface(
- context.user,
- mandateId=str(context.mandateId) if context.mandateId else None,
- featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
- )
+ managementInterface = _getInterfaceForOwnedItem(context.user, context, fileId, FileItem)
managementInterface.updateFile(fileId, {"neutralize": neutralize})
@@ -1212,7 +1403,8 @@ def update_file(
request: Request,
fileId: str = Path(..., description="ID of the file to update"),
file_info: Dict[str, Any] = Body(...),
- currentUser: User = Depends(getCurrentUser)
+ currentUser: User = Depends(getCurrentUser),
+ context: RequestContext = Depends(getRequestContext),
) -> FileItem:
"""Update file info"""
try:
@@ -1221,7 +1413,11 @@ def update_file(
if not safeData:
raise HTTPException(status_code=400, detail=routeApiMsg("No editable fields provided"))
- managementInterface = interfaceDbManagement.getInterface(currentUser)
+ managementInterface = interfaceDbManagement.getInterface(
+ currentUser,
+ mandateId=str(context.mandateId) if context.mandateId else None,
+ featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
+ )
file = managementInterface.getFile(fileId)
if not file:
@@ -1267,10 +1463,15 @@ def update_file(
def delete_file(
request: Request,
fileId: str = Path(..., description="ID of the file to delete"),
- currentUser: User = Depends(getCurrentUser)
+ currentUser: User = Depends(getCurrentUser),
+ context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
"""Delete a file"""
- managementInterface = interfaceDbManagement.getInterface(currentUser)
+ managementInterface = interfaceDbManagement.getInterface(
+ currentUser,
+ mandateId=str(context.mandateId) if context.mandateId else None,
+ featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
+ )
# Check if the file exists
existingFile = managementInterface.getFile(fileId)
diff --git a/modules/routes/routeDataSources.py b/modules/routes/routeDataSources.py
index 5dec19c8..b2f919b7 100644
--- a/modules/routes/routeDataSources.py
+++ b/modules/routes/routeDataSources.py
@@ -43,6 +43,49 @@ def _ensureConnectionKnowledgeFlag(rootIf, connectionId: str) -> None:
except Exception as e:
logger.warning("Could not auto-enable knowledgeIngestionEnabled for connection %s: %s", connectionId, e)
+def _computeOwnEffective(rootIf, rec, model, sourceId: str, flag: str) -> Any:
+ """Re-load the record after modification and compute its aggregate effective value."""
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
+ getEffectiveFlag, getEffectiveFlagFds,
+ )
+ freshRec = rootIf.db.getRecord(model, sourceId)
+ if not freshRec:
+ return None
+ if model is DataSource:
+ connectionId = freshRec.get("connectionId", "")
+ allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
+ return getEffectiveFlag(freshRec, flag, allDs, mode="aggregate")
+ else:
+ wsId = freshRec.get("workspaceInstanceId", "")
+ allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": wsId})
+ return getEffectiveFlagFds(freshRec, flag, allFds, mode="aggregate")
+
+
+def _computeAncestorEffectives(rootIf, rec, model, flag: str) -> List[Dict[str, Any]]:
+ """Compute the aggregate effective value for all ancestors of `rec`."""
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
+ collectAncestorChain, collectAncestorChainFds,
+ getEffectiveFlag, getEffectiveFlagFds,
+ )
+ effectiveKey = f"effective{flag[0].upper()}{flag[1:]}"
+ if model is DataSource:
+ connectionId = rec.get("connectionId", "")
+ allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
+ ancestors = collectAncestorChain(rec, allDs)
+ return [
+ {"id": a.get("id") or getattr(a, "id", ""), effectiveKey: getEffectiveFlag(a, flag, allDs, mode="aggregate")}
+ for a in ancestors
+ ]
+ else:
+ wsId = rec.get("workspaceInstanceId", "")
+ allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": wsId})
+ ancestors = collectAncestorChainFds(rec, allFds)
+ return [
+ {"id": a.get("id") or getattr(a, "id", ""), effectiveKey: getEffectiveFlagFds(a, flag, allFds, mode="aggregate")}
+ for a in ancestors
+ ]
+
+
router = APIRouter(
prefix="/api/datasources",
tags=["Data Sources"],
@@ -91,26 +134,41 @@ def _updateDataSourceScope(
try:
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
- cascadeResetDescendants,
- cascadeResetDescendantsFds,
+ cascadeResetDescendants, cascadeResetDescendantsFds,
+ getEffectiveFlag, getEffectiveFlagFds,
+ collectAncestorChain, collectAncestorChainFds,
)
rootIf = getRootInterface()
rec, model = _findSourceRecord(rootIf.db, sourceId)
if not rec:
raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
- rootIf.db.recordModify(model, sourceId, {"scope": scope})
- cascaded = 0
+ # 1. Cascade reset descendants bottom-up (before modifying master)
+ resetIds: List[str] = []
if scope is not None:
if model is DataSource:
- cascaded = cascadeResetDescendants(rootIf, rec, "scope")
+ resetIds = cascadeResetDescendants(rootIf, rec, "scope")
else:
- cascaded = cascadeResetDescendantsFds(rootIf, rec, "scope")
+ resetIds = cascadeResetDescendantsFds(rootIf, rec, "scope")
+
+ # 2. Set master value last (crash-safe)
+ rootIf.db.recordModify(model, sourceId, {"scope": scope})
+
+ # 3. Compute effective + ancestor chain for response
+ updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "scope")
+ effectiveScope = _computeOwnEffective(rootIf, rec, model, sourceId, "scope")
+
logger.info(
"Updated scope=%s for %s %s (cascade-reset %d descendants)",
- scope, model.__name__, sourceId, cascaded,
+ scope, model.__name__, sourceId, len(resetIds),
)
- return {"sourceId": sourceId, "scope": scope, "updated": True, "cascadedDescendants": cascaded}
+ return {
+ "sourceId": sourceId,
+ "scope": scope,
+ "effectiveScope": effectiveScope,
+ "resetDescendantIds": resetIds,
+ "updatedAncestors": updatedAncestors,
+ }
except HTTPException:
raise
except Exception as e:
@@ -133,26 +191,39 @@ def _updateDataSourceNeutralize(
try:
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
- cascadeResetDescendants,
- cascadeResetDescendantsFds,
+ cascadeResetDescendants, cascadeResetDescendantsFds,
)
rootIf = getRootInterface()
rec, model = _findSourceRecord(rootIf.db, sourceId)
if not rec:
raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
- rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize})
- cascaded = 0
+ # 1. Cascade reset descendants bottom-up (before modifying master)
+ resetIds: List[str] = []
if neutralize is not None:
if model is DataSource:
- cascaded = cascadeResetDescendants(rootIf, rec, "neutralize")
+ resetIds = cascadeResetDescendants(rootIf, rec, "neutralize")
else:
- cascaded = cascadeResetDescendantsFds(rootIf, rec, "neutralize")
+ resetIds = cascadeResetDescendantsFds(rootIf, rec, "neutralize")
+
+ # 2. Set master value last (crash-safe)
+ rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize})
+
+ # 3. Compute effective + ancestor chain for response
+ updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "neutralize")
+ effectiveNeutralize = _computeOwnEffective(rootIf, rec, model, sourceId, "neutralize")
+
logger.info(
"Updated neutralize=%s for %s %s (cascade-reset %d descendants)",
- neutralize, model.__name__, sourceId, cascaded,
+ neutralize, model.__name__, sourceId, len(resetIds),
)
- return {"sourceId": sourceId, "neutralize": neutralize, "updated": True, "cascadedDescendants": cascaded}
+ return {
+ "sourceId": sourceId,
+ "neutralize": neutralize,
+ "effectiveNeutralize": effectiveNeutralize,
+ "resetDescendantIds": resetIds,
+ "updatedAncestors": updatedAncestors,
+ }
except HTTPException:
raise
except Exception as e:
@@ -204,46 +275,57 @@ async def _updateDataSourceRagIndex(
`True` enqueues a mini-bootstrap. `False` synchronously purges chunks.
Must be `async def` so `await startJob(...)` registers `_runJob` in the
- main event loop. Sync route → worker thread → temporary loop closes
- before the task runs → job stays stuck forever.
+ main event loop.
"""
try:
from modules.interfaces.interfaceDbApp import getRootInterface
- from modules.serviceCenter.services.serviceKnowledge._inheritFlags import cascadeResetDescendants
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
+ cascadeResetDescendants, cascadeResetDescendantsFds,
+ )
rootIf = getRootInterface()
- rec = rootIf.db.getRecord(DataSource, sourceId)
+ rec, model = _findSourceRecord(rootIf.db, sourceId)
if not rec:
raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
- rootIf.db.recordModify(DataSource, sourceId, {"ragIndexEnabled": ragIndexEnabled})
- cascaded = 0
+ # 1. Cascade reset descendants bottom-up (before modifying master)
+ resetIds: List[str] = []
if ragIndexEnabled is not None:
- cascaded = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled")
+ if model is DataSource:
+ resetIds = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled")
+ else:
+ resetIds = cascadeResetDescendantsFds(rootIf, rec, "ragIndexEnabled")
+
+ # 2. Set master value last (crash-safe)
+ rootIf.db.recordModify(model, sourceId, {"ragIndexEnabled": ragIndexEnabled})
+
logger.info(
- "Updated ragIndexEnabled=%s for DataSource %s (cascade-reset %d descendants)",
- ragIndexEnabled, sourceId, cascaded,
+ "Updated ragIndexEnabled=%s for %s %s (cascade-reset %d descendants)",
+ ragIndexEnabled, model.__name__, sourceId, len(resetIds),
)
- connectionId = rec.get("connectionId") or rec.get("connection_id") or ""
- if ragIndexEnabled is True:
- _ensureConnectionKnowledgeFlag(rootIf, connectionId)
- from modules.serviceCenter.services.serviceBackgroundJobs import startJob
+ # Bootstrap / purge only for personal DataSource (file/folder-based RAG).
+ # FDS RAG is handled by the feature pipeline; the flag alone is enough.
+ if model is DataSource:
+ connectionId = rec.get("connectionId") or rec.get("connection_id") or ""
+ if ragIndexEnabled is True:
+ _ensureConnectionKnowledgeFlag(rootIf, connectionId)
+ from modules.serviceCenter.services.serviceBackgroundJobs import startJob
- conn = rootIf.getUserConnectionById(connectionId) if connectionId else None
- authority = ""
- if conn:
- authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "")
+ conn = rootIf.getUserConnectionById(connectionId) if connectionId else None
+ authority = ""
+ if conn:
+ authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "")
- await startJob(
- "connection.bootstrap",
- {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]},
- triggeredBy=str(context.user.id),
- )
- elif ragIndexEnabled is False:
- from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
- purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId)
- logger.info("Purged %d index rows / %d chunks for DataSource %s",
- purgeResult.get("indexRows", 0), purgeResult.get("chunks", 0), sourceId)
+ await startJob(
+ "connection.bootstrap",
+ {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]},
+ triggeredBy=str(context.user.id),
+ )
+ elif ragIndexEnabled is False:
+ from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
+ purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId)
+ logger.info("Purged %d index rows / %d chunks for DataSource %s",
+ purgeResult.get("indexRows", 0), purgeResult.get("chunks", 0), sourceId)
import json
from modules.shared.auditLogger import audit_logger
@@ -253,10 +335,20 @@ async def _updateDataSourceRagIndex(
mandateId=context.mandateId,
category=AuditCategory.PERMISSION.value,
action="rag_index_toggled",
- details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "cascadedDescendants": cascaded}),
+ details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "resetDescendants": len(resetIds), "model": model.__name__}),
)
- return {"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "updated": True, "cascadedDescendants": cascaded}
+ # 3. Compute effective + ancestors for response
+ updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "ragIndexEnabled")
+ effectiveRag = _computeOwnEffective(rootIf, rec, model, sourceId, "ragIndexEnabled")
+
+ return {
+ "sourceId": sourceId,
+ "ragIndexEnabled": ragIndexEnabled,
+ "effectiveRagIndexEnabled": effectiveRag,
+ "resetDescendantIds": resetIds,
+ "updatedAncestors": updatedAncestors,
+ }
except HTTPException:
raise
except Exception as e:
@@ -339,7 +431,17 @@ def _updateDataSourceSettings(
ownerId = str(rec.get("userId") or "")
currentUserId = str(context.user.id)
if ownerId and ownerId != currentUserId and not context.isSysAdmin:
- scope = str(rec.get("scope") or "personal")
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
+ if model is DataSource:
+ connectionId = rec.get("connectionId", "")
+ allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
+ scope = str(getEffectiveFlag(rec, "scope", allDs, mode="walk"))
+ else:
+ from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource as FDS
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds
+ wsId = rec.get("workspaceInstanceId", "")
+ allFds = rootIf.db.getRecordset(FDS, recordFilter={"workspaceInstanceId": wsId})
+ scope = str(getEffectiveFlagFds(rec, "scope", allFds, mode="walk"))
isMandateAdmin = getattr(context, "isMandateAdmin", False)
if scope == "personal" or not isMandateAdmin:
raise HTTPException(status_code=403, detail="Not allowed to modify this DataSource's settings")
diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py
index 99d5c4df..6a5e9eb5 100644
--- a/modules/routes/routeRagInventory.py
+++ b/modules/routes/routeRagInventory.py
@@ -86,6 +86,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"""
from modules.datamodels.datamodelDataSource import DataSource
from modules.datamodels.datamodelKnowledge import FileContentIndex
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
out = []
for conn in connections:
@@ -136,8 +137,8 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"label": ds.get("label") if isinstance(ds, dict) else getattr(ds, "label", ""),
"path": dsPath,
"sourceType": ds.get("sourceType") if isinstance(ds, dict) else getattr(ds, "sourceType", ""),
- "ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
- "neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
+ "ragIndexEnabled": getEffectiveFlag(ds, "ragIndexEnabled", dataSources, mode="walk"),
+ "neutralize": getEffectiveFlag(ds, "neutralize", dataSources, mode="walk"),
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
"fileCount": filesByDs.get(dsId, 0),
"chunkCount": chunksByDs.get(dsId, 0),
@@ -223,13 +224,165 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
return out
+def _buildFeatureInstanceInventory(featureInstanceIds, rootIf, knowledgeIf) -> List[Dict[str, Any]]:
+ """Build per-feature-instance RAG inventory rows.
+
+ Feature-instance data lives in FileContentIndex with a non-empty
+ featureInstanceId. Additionally each feature instance may have
+ FeatureDataSource rows that define which tables/data are visible
+ as sources, with their own ragIndexEnabled flags.
+ Includes feature.bootstrap job status (running/success/error).
+ """
+ from modules.datamodels.datamodelKnowledge import FileContentIndex
+ from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
+ from modules.interfaces.interfaceFeatures import getFeatureInterface
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds
+ from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
+ from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import FEATURE_BOOTSTRAP_JOB_TYPE
+
+ featureIf = getFeatureInterface(rootIf.db)
+
+ allFeatureJobs = jobService.listJobs(jobType=FEATURE_BOOTSTRAP_JOB_TYPE, limit=100)
+
+ out = []
+ for fiId in featureInstanceIds:
+ instance = featureIf.getFeatureInstance(fiId)
+ if not instance or not instance.enabled:
+ continue
+
+ indexRows = knowledgeIf.db.getRecordset(
+ FileContentIndex, recordFilter={"featureInstanceId": fiId}
+ )
+ fileIds = [
+ (r.get("id") if isinstance(r, dict) else getattr(r, "id", ""))
+ for r in indexRows
+ ]
+ fileIds = [fid for fid in fileIds if fid]
+ chunkCounts = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
+
+ statusCounts: Dict[str, int] = {}
+ for r in indexRows:
+ st = (r.get("status") if isinstance(r, dict) else getattr(r, "status", "unknown")) or "unknown"
+ statusCounts[st] = statusCounts.get(st, 0) + 1
+
+ allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": fiId})
+ dsItems = []
+ anyRagEnabled = False
+ for fds in allFds:
+ tblName = (fds.get("tableName") if isinstance(fds, dict) else getattr(fds, "tableName", "")) or ""
+ fCode = (fds.get("featureCode") if isinstance(fds, dict) else getattr(fds, "featureCode", "")) or ""
+ if tblName == "*" or not fCode:
+ continue
+ fdsId = fds.get("id") if isinstance(fds, dict) else getattr(fds, "id", "")
+ ragEnabled = getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate")
+ if ragEnabled:
+ anyRagEnabled = True
+ dsItems.append({
+ "id": fdsId,
+ "label": (fds.get("label") if isinstance(fds, dict) else getattr(fds, "label", "")) or "",
+ "tableName": tblName,
+ "featureCode": fCode,
+ "ragIndexEnabled": ragEnabled,
+ })
+
+ fiJobs = [
+ j for j in allFeatureJobs
+ if (j.get("payload") or {}).get("workspaceInstanceId") == fiId
+ ]
+ runningJobs = [
+ {
+ "jobId": j["id"],
+ "progress": j.get("progress", 0),
+ "progressMessage": (
+ resolveJobMessage(j.get("progressMessageData"))
+ or j.get("progressMessage", "")
+ ),
+ }
+ for j in fiJobs
+ if j.get("status") in ("PENDING", "RUNNING")
+ ]
+ lastError: Optional[Dict[str, Any]] = None
+ lastSuccess: Optional[Dict[str, Any]] = None
+ for j in fiJobs:
+ jStatus = j.get("status")
+ if jStatus == "ERROR" and lastError is None:
+ lastError = {
+ "jobId": j["id"],
+ "errorMessage": j.get("errorMessage", ""),
+ "finishedAt": j.get("finishedAt"),
+ }
+ elif jStatus == "SUCCESS" and lastSuccess is None:
+ result = j.get("result") or {}
+ lastSuccess = {
+ "jobId": j["id"],
+ "finishedAt": j.get("finishedAt"),
+ "indexed": result.get("indexed", 0),
+ "skippedDuplicate": result.get("skippedDuplicate", 0),
+ "failed": result.get("failed", 0),
+ }
+ if lastError and lastSuccess:
+ break
+
+ if not indexRows and not dsItems:
+ continue
+
+ out.append({
+ "featureInstanceId": fiId,
+ "featureCode": instance.featureCode,
+ "label": instance.label or instance.featureCode,
+ "mandateId": str(instance.mandateId) if instance.mandateId else "",
+ "fileCount": len(indexRows),
+ "chunkCount": sum(chunkCounts.values()),
+ "statusCounts": statusCounts,
+ "dataSources": dsItems,
+ "ragEnabled": anyRagEnabled,
+ "runningJobs": runningJobs,
+ "lastSuccess": lastSuccess,
+ "lastError": lastError,
+ })
+ return out
+
+
+@router.get("/my-mandates")
+@limiter.limit("30/minute")
+def _getMyMandates(
+ request: Request,
+ currentUser: User = Depends(getCurrentUser),
+) -> List[Dict[str, Any]]:
+ """Return mandates where the current user has an active membership.
+
+ Used by the RAG inventory frontend to populate the mandate dropdown
+ without requiring admin rights (unlike GET /api/mandates/).
+ """
+ try:
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ rootIf = getRootInterface()
+ userMandates = rootIf.getUserMandates(str(currentUser.id))
+ result = []
+ for um in userMandates:
+ if not um.enabled:
+ continue
+ mandate = rootIf.getMandate(str(um.mandateId))
+ if not mandate or not getattr(mandate, "enabled", True):
+ continue
+ result.append({
+ "id": str(um.mandateId),
+ "name": getattr(mandate, "name", ""),
+ "label": getattr(mandate, "label", None) or getattr(mandate, "name", ""),
+ })
+ return result
+ except Exception as e:
+ logger.error("Error in RAG inventory /my-mandates: %s", e, exc_info=True)
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.get("/me")
@limiter.limit("30/minute")
def _getInventoryMe(
request: Request,
currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
- """Personal RAG inventory: own connections + DataSources + chunk counts."""
+ """Personal RAG inventory: own connections + DataSources + chunk counts + feature uploads."""
try:
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
@@ -243,7 +396,20 @@ def _getInventoryMe(
totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
- return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
+ featureAccesses = rootIf.getFeatureAccessesForUser(str(currentUser.id))
+ fiIds = [
+ str(fa.featureInstanceId) for fa in featureAccesses
+ if fa.enabled and fa.featureInstanceId
+ ]
+ fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf)
+ totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems)
+ totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems)
+
+ return {
+ "connections": items,
+ "featureInstances": fiItems,
+ "totals": {"files": totalFiles, "chunks": totalChunks},
+ }
except Exception as e:
logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@@ -262,21 +428,43 @@ def _getInventoryMandate(
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface, aggregateMandateRagTotalBytes
from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
-
rootIf = getRootInterface()
knowledgeIf = getKnowledgeInterface(None)
- mandateId = str(context.mandateId) if context.mandateId else ""
+ mandateId = str(context.mandateId)
+ userId = str(context.user.id)
- from modules.datamodels.datamodelUam import UserConnection
- allConnections = rootIf.db.getRecordset(UserConnection, recordFilter={"mandateId": mandateId})
- connectionObjects = [type("C", (), row)() if isinstance(row, dict) else row for row in allConnections]
+ userMandates = rootIf.getUserMandates(userId)
+ isMember = any(
+ getattr(um, "mandateId", None) == mandateId and um.enabled
+ for um in userMandates
+ )
+ if not isMember and not context.isSysAdmin:
+ raise HTTPException(status_code=403, detail=routeApiMsg("No membership in this mandate"))
- items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
+ mandateMembers = rootIf.getUserMandatesByMandate(mandateId)
+ memberUserIds = {getattr(um, "userId", None) for um in mandateMembers}
+ memberUserIds.discard(None)
+
+ allConnections = []
+ for uid in memberUserIds:
+ allConnections.extend(rootIf.getUserConnections(uid))
+
+ items = _buildConnectionInventory(allConnections, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
totalBytes = aggregateMandateRagTotalBytes(mandateId)
- return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
+ mandateInstances = rootIf.getFeatureInstancesByMandate(mandateId, enabledOnly=True)
+ fiIds = [str(inst.id) for inst in mandateInstances if inst.id]
+ fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf)
+ totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems)
+ totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems)
+
+ return {
+ "connections": items,
+ "featureInstances": fiItems,
+ "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes},
+ }
except HTTPException:
raise
except Exception as e:
@@ -308,7 +496,22 @@ def _getInventoryPlatform(
totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
- return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
+ from modules.datamodels.datamodelFeatures import FeatureInstance
+ allInstances = rootIf.db.getRecordset(FeatureInstance, recordFilter={"enabled": True})
+ fiIds = [
+ (r.get("id") if isinstance(r, dict) else getattr(r, "id", ""))
+ for r in allInstances
+ ]
+ fiIds = [fid for fid in fiIds if fid]
+ fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf)
+ totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems)
+ totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems)
+
+ return {
+ "connections": items,
+ "featureInstances": fiItems,
+ "totals": {"files": totalFiles, "chunks": totalChunks},
+ }
except HTTPException:
raise
except Exception as e:
@@ -345,8 +548,9 @@ async def _reindexConnection(
if str(conn.userId) != str(currentUser.id):
raise HTTPException(status_code=403, detail="Not your connection")
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
- ragDs = [ds for ds in dataSources if (ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False))]
+ ragDs = [ds for ds in dataSources if getEffectiveFlag(ds, "ragIndexEnabled", dataSources, mode="walk") is True]
if not ragDs:
return {"status": "skipped", "reason": "no_rag_enabled_datasources"}
@@ -368,6 +572,47 @@ async def _reindexConnection(
raise HTTPException(status_code=500, detail=str(e))
+@router.post("/reindex-feature/{workspaceInstanceId}")
+@limiter.limit("10/minute")
+async def _reindexFeature(
+ request: Request,
+ workspaceInstanceId: str,
+ currentUser: User = Depends(getCurrentUser),
+) -> Dict[str, Any]:
+ """Re-trigger feature data bootstrap for a workspace instance.
+
+ Indexes all RAG-enabled FeatureDataSource rows into the knowledge store.
+ Must be ``async def`` so ``await startJob(...)`` registers in the main loop.
+ """
+ try:
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.serviceCenter.services.serviceBackgroundJobs import startJob
+ from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import FEATURE_BOOTSTRAP_JOB_TYPE
+
+ rootIf = getRootInterface()
+ featureAccesses = rootIf.getFeatureAccessesForUser(str(currentUser.id))
+ hasAccess = any(
+ str(fa.featureInstanceId) == workspaceInstanceId and fa.enabled
+ for fa in featureAccesses
+ )
+ if not hasAccess and not getattr(currentUser, "isSysAdmin", False):
+ raise HTTPException(status_code=403, detail="No access to this feature instance")
+
+ jobId = await startJob(
+ FEATURE_BOOTSTRAP_JOB_TYPE,
+ {"workspaceInstanceId": workspaceInstanceId},
+ triggeredBy=str(currentUser.id),
+ )
+
+ logger.info("Feature reindex triggered for workspace %s (jobId=%s)", workspaceInstanceId, jobId)
+ return {"status": "queued", "workspaceInstanceId": workspaceInstanceId, "jobId": jobId}
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error("Error triggering feature reindex: %s", e, exc_info=True)
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.get("/jobs")
@limiter.limit("60/minute")
def _getActiveJobs(
diff --git a/modules/security/rbac.py b/modules/security/rbac.py
index bec0b70e..59f8f55f 100644
--- a/modules/security/rbac.py
+++ b/modules/security/rbac.py
@@ -341,11 +341,10 @@ class RbacClass:
return []
try:
- conn = self.dbApp.connection
roleIds = set()
-
+
# 1. Mandant-Rollen via UserMandate → UserMandateRole (SINGLE Query)
- with conn.cursor() as cursor:
+ with self.dbApp.borrowCursor() as cursor:
cursor.execute(
"""
SELECT umr."roleId"
@@ -357,10 +356,10 @@ class RbacClass:
)
mandateRoles = cursor.fetchall()
roleIds.update(r["roleId"] for r in mandateRoles if r.get("roleId"))
-
+
# 2. Instanz-Rollen via FeatureAccess → FeatureAccessRole (SINGLE Query)
if featureInstanceId:
- with conn.cursor() as cursor:
+ with self.dbApp.borrowCursor() as cursor:
cursor.execute(
"""
SELECT far."roleId"
@@ -372,14 +371,13 @@ class RbacClass:
)
instanceRoles = cursor.fetchall()
roleIds.update(r["roleId"] for r in instanceRoles if r.get("roleId"))
-
+
if not roleIds:
return []
-
+
# 3. BULK Query: Alle Regeln für alle Rollen + zugehörige Role-Daten
- # SINGLE Query mit JOIN statt N+1
roleIdsList = list(roleIds)
- with conn.cursor() as cursor:
+ with self.dbApp.borrowCursor() as cursor:
cursor.execute(
"""
SELECT ar.*, r."mandateId" as "roleMandateId",
diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py
index fff1bcb3..dbd28dd4 100644
--- a/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py
+++ b/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py
@@ -67,7 +67,12 @@ def _registerDataSourceTools(registry: ToolRegistry, services):
sourceType = ds.get("sourceType", "")
path = ds.get("path", "/")
label = ds.get("label", "")
- neutralize = bool(ds.get("neutralize", False))
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
+ from modules.datamodels.datamodelDataSource import DataSource
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ rootIf = getRootInterface()
+ allConnDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
+ neutralize = bool(getEffectiveFlag(ds, "neutralize", allConnDs or [ds], mode="walk"))
service = _SOURCE_TYPE_TO_SERVICE.get(sourceType, sourceType)
if not connectionId:
raise ValueError(f"DataSource '{dsId}' has no connectionId")
diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py
index bdb3d23b..2ebc2720 100644
--- a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py
+++ b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py
@@ -110,9 +110,11 @@ def _registerFeatureSubAgentTools(registry: ToolRegistry, services):
recordFilter={"featureInstanceId": featureInstanceId, "workspaceInstanceId": workspaceInstanceId},
)
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds
+ _fdsAll = featureDataSources or []
_anySourceNeutralize = any(
- bool(ds.get("neutralize", False) if isinstance(ds, dict) else getattr(ds, "neutralize", False))
- for ds in (featureDataSources or [])
+ getEffectiveFlagFds(ds, "neutralize", _fdsAll, mode="walk") is True
+ for ds in _fdsAll
)
neutralizeFieldsPerTable: Dict[str, List[str]] = {}
diff --git a/modules/serviceCenter/services/serviceAgent/featureDataProvider.py b/modules/serviceCenter/services/serviceAgent/featureDataProvider.py
index d7707bdf..27ec36b2 100644
--- a/modules/serviceCenter/services/serviceAgent/featureDataProvider.py
+++ b/modules/serviceCenter/services/serviceAgent/featureDataProvider.py
@@ -95,8 +95,7 @@ class FeatureDataProvider:
def getActualColumns(self, tableName: str) -> List[str]:
"""Read real column names from PostgreSQL information_schema."""
try:
- conn = self._db.connection
- with conn.cursor() as cur:
+ with self._db.borrowCursor() as cur:
cur.execute(
"SELECT column_name FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
@@ -131,7 +130,6 @@ class FeatureDataProvider:
Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``.
"""
_validateTableName(tableName)
- conn = self._db.connection
if fields:
invalid = [f for f in fields if not _isValidIdentifier(f)]
@@ -141,7 +139,7 @@ class FeatureDataProvider:
"error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.",
}
- scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
+ scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db)
extraWhere, extraParams = _buildFilterClauses(extraFilters)
fullWhere = scopeFilter["where"]
@@ -152,7 +150,7 @@ class FeatureDataProvider:
t0 = time.time()
try:
- with conn.cursor() as cur:
+ with self._db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
cur.execute(countSql, allParams)
total = cur.fetchone()["count"] if cur.rowcount else 0
@@ -179,10 +177,6 @@ class FeatureDataProvider:
_debugQueryLog("browseTable", tableName, {
"fields": fields, "limit": limit, "offset": offset,
}, errResult, elapsed)
- try:
- conn.rollback()
- except Exception:
- pass
return errResult
def aggregateTable(
@@ -208,8 +202,7 @@ class FeatureDataProvider:
if groupBy and not _isValidIdentifier(groupBy):
return {"rows": [], "error": f"Invalid groupBy field: {groupBy}"}
- conn = self._db.connection
- scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
+ scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db)
extraWhere, extraParams = _buildFilterClauses(extraFilters)
fullWhere = scopeFilter["where"]
@@ -220,7 +213,7 @@ class FeatureDataProvider:
t0 = time.time()
try:
- with conn.cursor() as cur:
+ with self._db.borrowCursor() as cur:
if groupBy:
sql = (
f'SELECT "{groupBy}" AS "groupValue", {aggregate}("{field}") AS "result" '
@@ -253,10 +246,6 @@ class FeatureDataProvider:
_debugQueryLog("aggregateTable", tableName, {
"aggregate": aggregate, "field": field, "groupBy": groupBy,
}, errResult, elapsed)
- try:
- conn.rollback()
- except Exception:
- pass
return errResult
def queryTable(
@@ -277,7 +266,6 @@ class FeatureDataProvider:
``extraFilters`` are mandatory record-level scoping filters injected by the pipeline.
"""
_validateTableName(tableName)
- conn = self._db.connection
if fields:
invalid = [f for f in fields if not _isValidIdentifier(f)]
@@ -287,7 +275,7 @@ class FeatureDataProvider:
"error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.",
}
- scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
+ scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db)
combinedFilters = list(filters or []) + list(extraFilters or [])
extraWhere, extraParams = _buildFilterClauses(combinedFilters if combinedFilters else None)
@@ -300,7 +288,7 @@ class FeatureDataProvider:
t0 = time.time()
try:
- with conn.cursor() as cur:
+ with self._db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
cur.execute(countSql, allParams)
total = cur.fetchone()["count"] if cur.rowcount else 0
@@ -329,10 +317,6 @@ class FeatureDataProvider:
"filters": filters, "fields": fields, "orderBy": orderBy,
"limit": limit, "offset": offset,
}, errResult, elapsed)
- try:
- conn.rollback()
- except Exception:
- pass
return errResult
@@ -343,13 +327,13 @@ class FeatureDataProvider:
_instanceColCache: Dict[str, str] = {}
-def _resolveInstanceColumn(tableName: str, dbConnection=None) -> str:
+def _resolveInstanceColumn(tableName: str, db=None) -> str:
"""Detect whether the table uses ``instanceId`` or ``featureInstanceId``."""
if tableName in _instanceColCache:
return _instanceColCache[tableName]
- if dbConnection:
+ if db:
try:
- with dbConnection.cursor() as cur:
+ with db.borrowCursor() as cur:
cur.execute(
"SELECT column_name FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
@@ -378,14 +362,14 @@ def _isValidIdentifier(name: str) -> bool:
return name.isidentifier()
-def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, dbConnection=None) -> Dict[str, Any]:
+def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, db=None, dbConnection=None) -> Dict[str, Any]:
"""Build the mandatory WHERE clause that scopes rows to the feature instance.
Feature tables use either ``instanceId`` (commcoach, teamsbot) or
``featureInstanceId`` (trustee) as the FK. We detect the actual column
- from ``information_schema`` when a DB connection is provided.
+ from ``information_schema`` when a DB connector is provided.
"""
- instanceCol = _resolveInstanceColumn(tableName, dbConnection)
+ instanceCol = _resolveInstanceColumn(tableName, db or dbConnection)
conditions = []
params = []
diff --git a/modules/serviceCenter/services/serviceKnowledge/_buildTree.py b/modules/serviceCenter/services/serviceKnowledge/_buildTree.py
new file mode 100644
index 00000000..9179f3d8
--- /dev/null
+++ b/modules/serviceCenter/services/serviceKnowledge/_buildTree.py
@@ -0,0 +1,1020 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""Generic UDB Tree builder.
+
+The UDB shows three logical hierarchies as a single user-facing tree:
+ 1. Personal connections: UserConnection -> Service -> Folder -> File
+ 2. Mandate groups -> Feature instances -> FDS Workspace(*) -> FDS Table -> FDS Record
+ 3. (Settings/diagnostics nodes can be added later under the same model.)
+
+For every visible node the UI needs:
+ - a stable `key` (used both for expand-state and as parent reference)
+ - a `kind`, `label`, optional `icon`
+ - effective values for all three flags (neutralize, scope, ragIndexEnabled)
+ - whether a backing DB record exists (`dataSourceId` + `modelType`)
+ - whether the node has children to expand
+
+This module exposes one function: `getChildrenForParents(parents, ...)`.
+The caller asks for the children of a list of parent keys. The orchestrator
+does NOT decide what to expand; it only returns the children of what was
+asked for. This keeps the contract minimal and predictable.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple
+
+from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
+ resolveEffectiveForPath,
+ resolveEffectiveForFds,
+ _normalisePath,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Key encoding / decoding
+# ---------------------------------------------------------------------------
+# Format: "|||..." for data-bearing keys.
+# Synthetic container keys use a single literal token without separator.
+#
+# Top-level (parent=None) returns:
+# personalRoot (synthetic, groups all UserConnections)
+# mgrp| (one per accessible mandate)
+#
+# Data-bearing:
+# conn|
+# svc||
+# ds|||
+# mgrp|
+# feat|||
+# fdsws|| (synthetic '*' wildcard)
+# fdstbl||
+# fdsrec|||
+
+_KEY_SEP = "|"
+
+# Stable, parseable synthetic-container key. Never encoded with `_encode`
+# (no payload parts), always emitted/matched as literal.
+_KEY_PERSONAL_ROOT = "personalRoot"
+
+
+def _decode(key: str) -> Tuple[str, List[str]]:
+ parts = key.split(_KEY_SEP)
+ return parts[0], parts[1:]
+
+
+def _encode(kind: str, *parts: str) -> str:
+ return _KEY_SEP.join((kind, *parts))
+
+
+# ---------------------------------------------------------------------------
+# Sourcetype mapping (was hard-coded in frontend; now backend authority)
+# ---------------------------------------------------------------------------
+_SERVICE_TO_SOURCE_TYPE: Dict[str, str] = {
+ "sharepoint": "sharepointFolder",
+ "onedrive": "onedriveFolder",
+ "outlook": "outlookFolder",
+ "drive": "googleDriveFolder",
+ "gmail": "gmailFolder",
+ "files": "ftpFolder",
+ "clickup": "clickup",
+ "kdrive": "kdriveFolder",
+ "mail": "mailFolder",
+ "calendar": "calendarFolder",
+ "contact": "contactFolder",
+}
+
+_SERVICE_LABELS: Dict[str, str] = {
+ "sharepoint": "SharePoint",
+ "outlook": "Outlook",
+ "teams": "Teams",
+ "onedrive": "OneDrive",
+ "drive": "Google Drive",
+ "gmail": "Gmail",
+ "files": "Files (FTP)",
+ "kdrive": "kDrive",
+ "calendar": "Calendar",
+ "contact": "Contacts",
+}
+
+
+# ---------------------------------------------------------------------------
+# Per-node effective-value helpers
+# ---------------------------------------------------------------------------
+
+def _effectiveTripletDs(
+ connectionId: str,
+ sourceType: str,
+ path: str,
+ allDs: List[Dict[str, Any]],
+) -> Dict[str, Any]:
+ """Return {effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled}
+ for an arbitrary DS coordinate (whether or not a record exists)."""
+ out = resolveEffectiveForPath(connectionId, sourceType, path, allDs, mode="aggregate")
+ return {
+ "effectiveNeutralize": out.get("effectiveNeutralize", False),
+ "effectiveScope": out.get("effectiveScope", "personal"),
+ "effectiveRagIndexEnabled": out.get("effectiveRagIndexEnabled", False),
+ }
+
+
+def _effectiveTripletFds(
+ featureInstanceId: str,
+ tableName: str,
+ recordFilter: Optional[Dict[str, str]],
+ allFds: List[Dict[str, Any]],
+) -> Dict[str, Any]:
+ """Return effective-triplet for an FDS coordinate."""
+ out = resolveEffectiveForFds(featureInstanceId, tableName, recordFilter, allFds, mode="aggregate")
+ return {
+ "effectiveNeutralize": out.get("effectiveNeutralize", False),
+ "effectiveScope": out.get("effectiveScope", "personal"),
+ "effectiveRagIndexEnabled": out.get("effectiveRagIndexEnabled", False),
+ }
+
+
+def _findDsRecord(
+ allDs: List[Dict[str, Any]],
+ connectionId: str,
+ sourceType: str,
+ path: str,
+) -> Optional[Dict[str, Any]]:
+ norm = _normalisePath(path)
+ for ds in allDs:
+ if (
+ ds.get("connectionId") == connectionId
+ and ds.get("sourceType") == sourceType
+ and _normalisePath(ds.get("path")) == norm
+ ):
+ return ds
+ return None
+
+
+def _findFdsRecord(
+ allFds: List[Dict[str, Any]],
+ featureInstanceId: str,
+ tableName: str,
+ recordFilter: Optional[Dict[str, str]] = None,
+) -> Optional[Dict[str, Any]]:
+ """Find a FeatureDataSource record by featureInstanceId + tableName.
+
+ `allFds` is already scoped to the workspace (loaded with
+ recordFilter={'workspaceInstanceId': wsInstanceId}), so the
+ distinguishing coordinate is featureInstanceId + tableName.
+ """
+ target = recordFilter or None
+ for fds in allFds:
+ if (
+ fds.get("featureInstanceId") == featureInstanceId
+ and fds.get("tableName") == tableName
+ and (fds.get("recordFilter") or None) == target
+ ):
+ return fds
+ return None
+
+
+# ---------------------------------------------------------------------------
+# Synthetic container helpers
+# ---------------------------------------------------------------------------
+
+def _emptyTriplet() -> Dict[str, Any]:
+ """Synthetic container nodes carry no DB record and no inherited flags.
+ Backend reports neutral defaults so the UI never reads stale values for them."""
+ return {
+ "effectiveNeutralize": False,
+ "effectiveScope": "personal",
+ "effectiveRagIndexEnabled": False,
+ }
+
+
+def _syntheticNode(
+ key: str,
+ parentKey: Optional[str],
+ label: str,
+ icon: str,
+ displayOrder: int,
+ defaultExpanded: bool = False,
+) -> Dict[str, Any]:
+ """Build a synthetic container node (no DB record, not flag-toggleable)."""
+ return {
+ "key": key,
+ "kind": "synthRoot",
+ "parentKey": parentKey,
+ "label": label,
+ "icon": icon,
+ "hasChildren": True,
+ "dataSourceId": None,
+ "modelType": None,
+ **_emptyTriplet(),
+ "supportsRag": False,
+ "canBeAdded": False,
+ "displayOrder": displayOrder,
+ "defaultExpanded": defaultExpanded,
+ }
+
+
+# ---------------------------------------------------------------------------
+# Top-level (parent = None) -> personalRoot + mandate groups (flat layout)
+# ---------------------------------------------------------------------------
+
+def _topLevel(
+ instanceId: str,
+ context: Any,
+ rootIf: Any,
+ _allDs: List[Dict[str, Any]],
+ allFds: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """Return the visible top-level: 'personalRoot' first, then one node per
+ accessible mandate group. Both layers are marked `defaultExpanded=True`
+ so the UI opens down to the data-source level on first render.
+ """
+ nodes: List[Dict[str, Any]] = [
+ _syntheticNode(
+ key=_KEY_PERSONAL_ROOT,
+ parentKey=None,
+ label=resolveTextSafe("Persönliche Quellen"),
+ icon="person",
+ displayOrder=0,
+ defaultExpanded=True,
+ )
+ ]
+ nodes.extend(_listMandateGroups(instanceId, context, rootIf, allFds))
+ return nodes
+
+
+# ---------------------------------------------------------------------------
+# Children of personalRoot -> active UserConnections
+# ---------------------------------------------------------------------------
+
+def _personalRootChildren(
+ instanceId: str,
+ context: Any,
+ allDs: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """Return one node per active UserConnection of the current user."""
+ from modules.serviceCenter import getService
+ from modules.serviceCenter.context import ServiceCenterContext
+
+ mandateId = getattr(context, "mandateId", "") or ""
+ ctx = ServiceCenterContext(
+ user=context.user,
+ mandate_id=mandateId,
+ feature_instance_id=instanceId,
+ )
+ chatService = getService("chat", ctx)
+ connections = chatService.getUserConnections() or []
+
+ nodes: List[Dict[str, Any]] = []
+ for c in connections:
+ conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {})
+ status = conn.get("status")
+ if hasattr(status, "value"):
+ status = status.value
+ if status != "active":
+ continue
+ authority = conn.get("authority")
+ if hasattr(authority, "value"):
+ authority = authority.value
+ connId = conn.get("id") or ""
+ label = conn.get("externalEmail") or conn.get("externalUsername") or authority or ""
+ # Connection root = path '/' on its authority sourceType.
+ triplet = _effectiveTripletDs(connId, str(authority), "/", allDs)
+ rec = _findDsRecord(allDs, connId, str(authority), "/")
+ nodes.append({
+ "key": _encode("conn", connId),
+ "kind": "connection",
+ "parentKey": _KEY_PERSONAL_ROOT,
+ "label": label,
+ "icon": str(authority),
+ "hasChildren": True,
+ "dataSourceId": rec.get("id") if rec else None,
+ "modelType": "DataSource" if rec else None,
+ **triplet,
+ "supportsRag": True,
+ "canBeAdded": rec is None,
+ "authority": authority,
+ "connectionId": connId,
+ })
+ return nodes
+
+
+# ---------------------------------------------------------------------------
+# Mandate-group nodes (rendered top-level next to personalRoot)
+# ---------------------------------------------------------------------------
+
+def _listMandateGroups(
+ _instanceId: str,
+ context: Any,
+ rootIf: Any,
+ _allFds: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """Return one mandate-group node per accessible mandate that has at least
+ one enabled feature instance with registered DATA objects.
+
+ Emitted at the top level (parentKey=None). `defaultExpanded=True` so the
+ UI shows feature-instance children (= mandate data sources) without a
+ second user click.
+ """
+ from modules.security.rbacCatalog import getCatalogService
+ from modules.datamodels.datamodelUam import Mandate
+
+ userId = str(context.user.id)
+ catalog = getCatalogService()
+ featureCodesWithData = catalog.getFeaturesWithDataObjects()
+ userMandates = rootIf.getUserMandates(userId)
+
+ wsMandateId = getattr(context, "mandateId", None)
+ allowedMandateIds = {um.mandateId for um in (userMandates or [])}
+ if wsMandateId and wsMandateId in allowedMandateIds:
+ allowedMandateIds = {wsMandateId}
+
+ mandateLabels: Dict[str, str] = {}
+ for um in userMandates or []:
+ if um.mandateId not in allowedMandateIds:
+ continue
+ try:
+ rows = rootIf.db.getRecordset(Mandate, recordFilter={"id": um.mandateId})
+ if rows:
+ m = rows[0]
+ mandateLabels[um.mandateId] = m.get("label") or m.get("name") or um.mandateId
+ except Exception:
+ mandateLabels[um.mandateId] = um.mandateId
+
+ nodes: List[Dict[str, Any]] = []
+ seenMandates: set = set()
+ for um in userMandates or []:
+ mid = um.mandateId
+ if mid in seenMandates or mid not in allowedMandateIds:
+ continue
+ seenMandates.add(mid)
+ instances = rootIf.getFeatureInstancesByMandate(mid)
+ hasFeature = False
+ for inst in instances:
+ if inst.enabled and inst.featureCode in featureCodesWithData:
+ fa = rootIf.getFeatureAccess(userId, inst.id)
+ if fa and fa.enabled:
+ hasFeature = True
+ break
+ if not hasFeature:
+ continue
+ nodes.append({
+ "key": _encode("mgrp", mid),
+ "kind": "mandateGroup",
+ "parentKey": None,
+ "label": mandateLabels.get(mid, mid),
+ "icon": "mandate",
+ "hasChildren": True,
+ "dataSourceId": None,
+ "modelType": None,
+ **_emptyTriplet(),
+ "supportsRag": False,
+ "canBeAdded": False,
+ "mandateId": mid,
+ "defaultExpanded": True,
+ })
+ return nodes
+
+
+# ---------------------------------------------------------------------------
+# Children of a connection -> services
+# ---------------------------------------------------------------------------
+
+async def _connectionServices(
+ instanceId: str,
+ context: Any,
+ connectionId: str,
+ allDs: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ from modules.connectors.connectorResolver import ConnectorResolver
+ from modules.serviceCenter import getService
+ from modules.serviceCenter.context import ServiceCenterContext
+
+ mandateId = getattr(context, "mandateId", "") or ""
+ ctx = ServiceCenterContext(
+ user=context.user,
+ mandate_id=mandateId,
+ feature_instance_id=instanceId,
+ )
+ chatService = getService("chat", ctx)
+ securityService = getService("security", ctx)
+ from modules.features.workspace.routeFeatureWorkspace import _buildResolverDbInterface
+ dbInterface = _buildResolverDbInterface(chatService)
+ resolver = ConnectorResolver(securityService, dbInterface)
+ try:
+ provider = await resolver.resolve(connectionId)
+ services = provider.getAvailableServices()
+ except Exception as exc:
+ logger.error("Tree: cannot resolve services for connection %s: %s", connectionId, exc)
+ return []
+
+ nodes: List[Dict[str, Any]] = []
+ for service in services or []:
+ sourceType = _SERVICE_TO_SOURCE_TYPE.get(service, service)
+ triplet = _effectiveTripletDs(connectionId, sourceType, "/", allDs)
+ rec = _findDsRecord(allDs, connectionId, sourceType, "/")
+ nodes.append({
+ "key": _encode("svc", connectionId, service),
+ "kind": "service",
+ "parentKey": _encode("conn", connectionId),
+ "label": _SERVICE_LABELS.get(service, service),
+ "icon": service,
+ "hasChildren": True,
+ "dataSourceId": rec.get("id") if rec else None,
+ "modelType": "DataSource" if rec else None,
+ **triplet,
+ "supportsRag": True,
+ "canBeAdded": rec is None,
+ "connectionId": connectionId,
+ "service": service,
+ "sourceType": sourceType,
+ "path": "/",
+ })
+ return nodes
+
+
+# ---------------------------------------------------------------------------
+# Children of a folder/service -> next-level folders+files via browse
+# ---------------------------------------------------------------------------
+
+async def _browseChildren(
+ instanceId: str,
+ context: Any,
+ connectionId: str,
+ service: str,
+ sourceType: str,
+ parentPath: str,
+ allDs: List[Dict[str, Any]],
+ parentKey: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+ from modules.connectors.connectorResolver import ConnectorResolver
+ from modules.serviceCenter import getService
+ from modules.serviceCenter.context import ServiceCenterContext
+
+ mandateId = getattr(context, "mandateId", "") or ""
+ ctx = ServiceCenterContext(
+ user=context.user,
+ mandate_id=mandateId,
+ feature_instance_id=instanceId,
+ )
+ chatService = getService("chat", ctx)
+ securityService = getService("security", ctx)
+ from modules.features.workspace.routeFeatureWorkspace import _buildResolverDbInterface
+ dbInterface = _buildResolverDbInterface(chatService)
+ resolver = ConnectorResolver(securityService, dbInterface)
+ try:
+ adapter = await resolver.resolveService(connectionId, service)
+ entries = await adapter.browse(parentPath, filter=None)
+ except Exception as exc:
+ logger.error("Tree: cannot browse %s on connection %s path=%s: %s", service, connectionId, parentPath, exc)
+ return []
+
+ # Children parentKey must equal the key the caller asked for (= the
+ # currently-expanded node in the UI). If the caller doesn't pass an
+ # explicit key, fall back to the encoded ds-coordinate.
+ effectiveParentKey = parentKey if parentKey is not None else _encode("ds", connectionId, sourceType, parentPath)
+ nodes: List[Dict[str, Any]] = []
+ for e in entries or []:
+ path = getattr(e, "path", "") or ""
+ kind = "folder" if getattr(e, "isFolder", False) else "file"
+ triplet = _effectiveTripletDs(connectionId, sourceType, path, allDs)
+ rec = _findDsRecord(allDs, connectionId, sourceType, path)
+ nodes.append({
+ "key": _encode("ds", connectionId, sourceType, path),
+ "kind": kind,
+ "parentKey": effectiveParentKey,
+ "label": getattr(e, "name", "") or path,
+ "icon": kind,
+ "hasChildren": kind == "folder",
+ "dataSourceId": rec.get("id") if rec else None,
+ "modelType": "DataSource" if rec else None,
+ **triplet,
+ "supportsRag": True,
+ "canBeAdded": rec is None,
+ "connectionId": connectionId,
+ "service": service,
+ "sourceType": sourceType,
+ "path": path,
+ })
+ return nodes
+
+
+# ---------------------------------------------------------------------------
+# Mandate group -> feature connections
+# ---------------------------------------------------------------------------
+
+def _featureConnectionsForMandate(
+ instanceId: str,
+ context: Any,
+ rootIf: Any,
+ mandateId: str,
+ allFds: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ from modules.security.rbacCatalog import getCatalogService
+
+ userId = str(context.user.id)
+ catalog = getCatalogService()
+ featureCodesWithData = catalog.getFeaturesWithDataObjects()
+ instances = rootIf.getFeatureInstancesByMandate(mandateId)
+
+ parentKey = _encode("mgrp", mandateId)
+ nodes: List[Dict[str, Any]] = []
+ for inst in instances or []:
+ if not inst.enabled:
+ continue
+ if inst.featureCode not in featureCodesWithData:
+ continue
+ fa = rootIf.getFeatureAccess(userId, inst.id)
+ if not fa or not fa.enabled:
+ continue
+ # Effective values come from the FDS workspace-wildcard for this featureInstanceId
+ wsId = inst.id
+ triplet = _effectiveTripletFds(wsId, "*", None, allFds)
+ rec = _findFdsRecord(allFds, wsId, "*", None)
+ featureDef = catalog.getFeatureDefinition(inst.featureCode) or {}
+ nodes.append({
+ "key": _encode("feat", mandateId, inst.featureCode, inst.id),
+ "kind": "featureNode",
+ "parentKey": parentKey,
+ "label": inst.label or inst.featureCode,
+ "icon": featureDef.get("icon", "mdi-database"),
+ "hasChildren": True,
+ "dataSourceId": rec.get("id") if rec else None,
+ "modelType": "FeatureDataSource" if rec else None,
+ **triplet,
+ "supportsRag": True,
+ "canBeAdded": rec is None,
+ "featureInstanceId": wsId,
+ "featureCode": inst.featureCode,
+ "mandateId": mandateId,
+ "tableName": "*",
+ })
+ return nodes
+
+
+# ---------------------------------------------------------------------------
+# Feature node -> tables
+# ---------------------------------------------------------------------------
+
+def _featureTables(
+ context: Any,
+ rootIf: Any,
+ parentKey: str,
+ featureInstanceId: str,
+ featureCode: str,
+ allFds: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ from modules.security.rbacCatalog import getCatalogService
+
+ inst = rootIf.getFeatureInstance(featureInstanceId)
+ if not inst:
+ return []
+ catalog = getCatalogService()
+ try:
+ from modules.security.rbac import RbacClass
+ from modules.security.rootAccess import getRootDbAppConnector
+ dbApp = getRootDbAppConnector()
+ rbac = RbacClass(dbApp, dbApp=dbApp)
+ accessible = catalog.getAccessibleDataObjects(
+ featureCode=inst.featureCode,
+ rbacInstance=rbac,
+ user=context.user,
+ mandateId=str(inst.mandateId) if inst.mandateId else "",
+ featureInstanceId=featureInstanceId,
+ )
+ except Exception:
+ accessible = catalog.getDataObjects(inst.featureCode)
+
+ accessibleKeys = {obj.get("objectKey", "") for obj in accessible}
+
+ nodes: List[Dict[str, Any]] = []
+ for obj in catalog.getDataObjects(inst.featureCode):
+ meta = obj.get("meta", {})
+ if meta.get("wildcard") or meta.get("isGroup"):
+ continue
+ objectKey = obj.get("objectKey", "")
+ if objectKey not in accessibleKeys:
+ continue
+ tableName = meta.get("table", "")
+ if not tableName:
+ continue
+ triplet = _effectiveTripletFds(featureInstanceId, tableName, None, allFds)
+ rec = _findFdsRecord(allFds, featureInstanceId, tableName, None)
+ fields = meta.get("fields") if isinstance(meta, dict) else None
+ hasFields = bool(isinstance(fields, list) and len(fields) > 0)
+ # Surface the persisted per-field neutralize list so the UI can
+ # render & toggle field-level icons without an extra GET.
+ neutralizeFields: List[str] = []
+ if rec and isinstance(rec.get("neutralizeFields"), list):
+ neutralizeFields = [f for f in rec["neutralizeFields"] if isinstance(f, str)]
+ nodes.append({
+ "key": _encode("fdstbl", featureInstanceId, tableName),
+ "kind": "fdsTable",
+ "parentKey": parentKey,
+ "label": resolveTextSafe(obj.get("label", "")) or tableName,
+ "icon": "table",
+ # Children = the per-column field nodes. Only emitted when the
+ # data-object metadata declared a non-empty `fields` list.
+ "hasChildren": hasFields,
+ "dataSourceId": rec.get("id") if rec else None,
+ "modelType": "FeatureDataSource" if rec else None,
+ **triplet,
+ "supportsRag": True,
+ "canBeAdded": rec is None,
+ "featureInstanceId": featureInstanceId,
+ "featureCode": featureCode,
+ "tableName": tableName,
+ "objectKey": objectKey,
+ "neutralizeFields": neutralizeFields,
+ })
+ return nodes
+
+
+def _featureTableFields(
+ parentKey: str,
+ featureInstanceId: str,
+ tableName: str,
+ fieldNames: List[str],
+ allFds: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """Emit one node per declared column of a feature data table.
+
+ Per-field neutralize semantics:
+ - The table-level FDS record carries `neutralizeFields: List[str]`.
+ - A field is "effectively neutralized" iff its name is in that list
+ OR the table's effective `neutralize` is True (blanket).
+ - Only `neutralize` is meaningful per-field; `scope` and `ragIndexEnabled`
+ are inherited from the parent table and not toggleable here.
+ """
+ rec = _findFdsRecord(allFds, featureInstanceId, tableName, None)
+ tableNeutralize = bool(rec.get("neutralize")) if rec else False
+ neutralizeFields = rec.get("neutralizeFields") if rec else None
+ if not isinstance(neutralizeFields, list):
+ neutralizeFields = []
+
+ nodes: List[Dict[str, Any]] = []
+ for field in fieldNames:
+ if not field:
+ continue
+ fieldNeutralized = bool(tableNeutralize or field in neutralizeFields)
+ nodes.append({
+ "key": _encode("fdsfld", featureInstanceId, tableName, field),
+ "kind": "fdsField",
+ "parentKey": parentKey,
+ "label": field,
+ "icon": "field",
+ "hasChildren": False,
+ "dataSourceId": rec.get("id") if rec else None,
+ "modelType": "FeatureDataSource" if rec else None,
+ "effectiveNeutralize": fieldNeutralized,
+ # Field-level scope/RAG do not exist as a concept; the FE hides
+ # those affordances when supportsRag=False. We still need
+ # `effectiveScope` + `effectiveRagIndexEnabled` for the
+ # contract; they reflect the parent's effective values so the
+ # backend stays single source of truth.
+ "effectiveScope": "personal",
+ "effectiveRagIndexEnabled": False,
+ "supportsRag": False,
+ "canBeAdded": rec is None,
+ "featureInstanceId": featureInstanceId,
+ "tableName": tableName,
+ "fieldName": field,
+ })
+ return nodes
+
+
+def resolveTextSafe(label: Any) -> str:
+ try:
+ from modules.shared.i18nRegistry import resolveText
+ return resolveText(label)
+ except Exception:
+ return str(label or "")
+
+
+# ---------------------------------------------------------------------------
+# Public entrypoint
+# ---------------------------------------------------------------------------
+
+async def getChildrenForParents(
+ instanceId: str,
+ parents: List[Optional[str]],
+ context: Any,
+) -> Dict[str, List[Dict[str, Any]]]:
+ """Return per-parent children lists.
+
+ `parents` is a list with `None` representing the top-level. Order is preserved.
+ Returns a dict keyed by parent key (or '__root__' for None).
+
+ Each child is a fully-rendered TreeNode dict (see module docstring for shape).
+ """
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.datamodels.datamodelDataSource import DataSource
+ from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
+
+ rootIf = getRootInterface()
+
+ # Pre-load DS (per user) and FDS (per workspace) once for the whole request.
+ userId = str(context.user.id)
+ allDs = rootIf.db.getRecordset(DataSource, recordFilter={"userId": userId}) or []
+ allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": instanceId}) or []
+
+ out: Dict[str, List[Dict[str, Any]]] = {}
+
+ for parentKey in parents:
+ if parentKey is None:
+ try:
+ out["__root__"] = _topLevel(instanceId, context, rootIf, allDs, allFds)
+ except Exception as exc:
+ logger.exception("Tree top-level failed: %s", exc)
+ out["__root__"] = []
+ continue
+
+ try:
+ kind, parts = _decode(parentKey)
+ except Exception:
+ out[parentKey] = []
+ continue
+
+ try:
+ if parentKey == _KEY_PERSONAL_ROOT:
+ out[parentKey] = _personalRootChildren(instanceId, context, allDs)
+
+ elif kind == "conn" and len(parts) == 1:
+ out[parentKey] = await _connectionServices(instanceId, context, parts[0], allDs)
+
+ elif kind == "svc" and len(parts) == 2:
+ connId, service = parts
+ sourceType = _SERVICE_TO_SOURCE_TYPE.get(service, service)
+ out[parentKey] = await _browseChildren(
+ instanceId, context, connId, service, sourceType, "/", allDs,
+ parentKey=parentKey,
+ )
+
+ elif kind == "ds" and len(parts) == 3:
+ connId, sourceType, path = parts
+ # Determine service from sourceType (reverse map)
+ service = _reverseService(sourceType)
+ out[parentKey] = await _browseChildren(
+ instanceId, context, connId, service, sourceType, path, allDs,
+ parentKey=parentKey,
+ )
+
+ elif kind == "mgrp" and len(parts) == 1:
+ out[parentKey] = _featureConnectionsForMandate(instanceId, context, rootIf, parts[0], allFds)
+
+ elif kind == "feat" and len(parts) == 3:
+ _mandateId, featureCode, featureInstanceId = parts
+ out[parentKey] = _featureTables(context, rootIf, parentKey, featureInstanceId, featureCode, allFds)
+
+ elif kind == "fdstbl" and len(parts) == 2:
+ featureInstanceId, tableName = parts
+ fieldNames = _resolveTableFieldNames(featureInstanceId, tableName, rootIf)
+ out[parentKey] = _featureTableFields(
+ parentKey, featureInstanceId, tableName, fieldNames, allFds,
+ )
+
+ else:
+ out[parentKey] = []
+ except Exception as exc:
+ logger.exception("Tree children for %s failed: %s", parentKey, exc)
+ out[parentKey] = []
+
+ return out
+
+
+def _reverseService(sourceType: str) -> str:
+ for svc, st in _SERVICE_TO_SOURCE_TYPE.items():
+ if st == sourceType:
+ return svc
+ return sourceType
+
+
+def _resolveTableFieldNames(featureInstanceId: str, tableName: str, rootIf: Any) -> List[str]:
+ """Look up the declared column list for a (featureInstance, tableName)
+ pair via the RBAC catalog data-object metadata. Returns empty list when
+ the catalog has no entry (e.g. wildcard-only feature)."""
+ from modules.security.rbacCatalog import getCatalogService
+ inst = rootIf.getFeatureInstance(featureInstanceId)
+ if not inst:
+ return []
+ catalog = getCatalogService()
+ for obj in catalog.getDataObjects(inst.featureCode) or []:
+ meta = obj.get("meta", {}) if isinstance(obj, dict) else {}
+ if meta.get("table") == tableName:
+ fields = meta.get("fields")
+ if isinstance(fields, list):
+ return [f for f in fields if isinstance(f, str) and f]
+ return []
+ return []
+
+
+# ---------------------------------------------------------------------------
+# Attribute-only refresh: given node keys, return current effective values
+# ---------------------------------------------------------------------------
+
+async def getAttributesForKeys(
+ instanceId: str,
+ keys: List[str],
+ context: Any,
+) -> Dict[str, Dict[str, Any]]:
+ """Return effective attribute values for a list of node keys.
+
+ Used by the frontend after a toggle to refresh only attributes (neutralize,
+ scope, ragIndexEnabled) without reloading the tree structure. For container
+ nodes (personalRoot, mgrp), aggregates child values and returns 'mixed'
+ when children diverge."""
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.datamodels.datamodelDataSource import DataSource
+ from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
+
+ rootIf = getRootInterface()
+ userId = str(context.user.id)
+ allDs = rootIf.db.getRecordset(DataSource, recordFilter={"userId": userId}) or []
+ allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": instanceId}) or []
+
+ result: Dict[str, Dict[str, Any]] = {}
+
+ for key in keys:
+ try:
+ attrs = _resolveAttrsForKey(key, allDs, allFds, instanceId, context, rootIf)
+ if attrs is not None:
+ result[key] = attrs
+ if "mixed" in str(attrs.values()):
+ logger.info("getAttributesForKeys key=%s returned MIXED: %s", key, attrs)
+ except Exception as exc:
+ logger.warning("getAttributesForKeys failed for key=%s: %s", key, exc)
+
+ logger.info("getAttributesForKeys: %d keys requested, %d resolved", len(keys), len(result))
+ return result
+
+
+def _resolveAttrsForKey(
+ key: str,
+ allDs: List[Dict[str, Any]],
+ allFds: List[Dict[str, Any]],
+ instanceId: str,
+ context: Any,
+ rootIf: Any,
+) -> Optional[Dict[str, Any]]:
+ """Resolve effective attributes for a single node key."""
+ if key == _KEY_PERSONAL_ROOT:
+ return _aggregatePersonalRoot(allDs)
+
+ try:
+ kind, parts = _decode(key)
+ except Exception:
+ return None
+
+ if kind == "mgrp" and len(parts) == 1:
+ return _aggregateMandateGroup(parts[0], allFds, instanceId, context, rootIf)
+
+ if kind == "conn" and len(parts) == 1:
+ connId = parts[0]
+ return _aggregateConnection(connId, allDs)
+
+ if kind == "svc" and len(parts) == 2:
+ connId, service = parts
+ sourceType = _SERVICE_TO_SOURCE_TYPE.get(service, service)
+ return _effectiveTripletDs(connId, sourceType, "/", allDs)
+
+ if kind == "ds" and len(parts) == 3:
+ connId, sourceType, path = parts
+ return _effectiveTripletDs(connId, sourceType, path, allDs)
+
+ if kind == "feat" and len(parts) == 3:
+ _mandateId, _featureCode, featureInstanceId = parts
+ return _effectiveTripletFds(featureInstanceId, "*", None, allFds)
+
+ if kind == "fdsws" and len(parts) == 2:
+ workspaceInstanceId, _featureCode = parts
+ return _effectiveTripletFds(workspaceInstanceId, "*", None, allFds)
+
+ if kind == "fdstbl" and len(parts) == 2:
+ featureInstanceId, tableName = parts
+ return _effectiveTripletFds(featureInstanceId, tableName, None, allFds)
+
+ if kind == "fdsrec" and len(parts) == 3:
+ featureInstanceId, tableName, recordId = parts
+ return _effectiveTripletFds(featureInstanceId, tableName, {"objectKey": recordId}, allFds)
+
+ if kind == "fdsfld" and len(parts) >= 3:
+ featureInstanceId, tableName = parts[0], parts[1]
+ fieldName = parts[2] if len(parts) > 2 else ""
+ parentFds = None
+ for fds in allFds:
+ if (fds.get("featureInstanceId") == featureInstanceId
+ and (fds.get("tableName") or "") == tableName
+ and fds.get("recordFilter") is None):
+ parentFds = fds
+ break
+ neutralizeFields = (parentFds.get("neutralizeFields") or []) if parentFds else []
+ return {"effectiveNeutralize": fieldName in neutralizeFields}
+
+ return None
+
+
+def _aggregateConnection(connId: str, allDs: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """Aggregate effective values for a connection node.
+
+ If the connection has an authority-level DS record (path="/"), use the
+ standard aggregate mode on it (which already handles subtree correctly).
+ Otherwise compute effective values for each child DS using walk mode and
+ aggregate them manually."""
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
+ getEffectiveFlag, _AUTHORITY_SOURCE_TYPES,
+ )
+ connRecords = [d for d in allDs if d.get("connectionId") == connId]
+ if not connRecords:
+ return {"effectiveNeutralize": False, "effectiveScope": "personal", "effectiveRagIndexEnabled": False}
+
+ rootRec = None
+ for r in connRecords:
+ st = r.get("sourceType", "")
+ if st in _AUTHORITY_SOURCE_TYPES and _normalisePath(r.get("path", "")) == "/":
+ rootRec = r
+ break
+
+ if rootRec:
+ return _effectiveTripletDs(connId, rootRec.get("sourceType", ""), "/", allDs)
+
+ neutralizeVals = set()
+ scopeVals = set()
+ ragVals = set()
+ for r in connRecords:
+ neutralizeVals.add(getEffectiveFlag(r, "neutralize", allDs, mode="walk"))
+ scopeVals.add(getEffectiveFlag(r, "scope", allDs, mode="walk"))
+ ragVals.add(getEffectiveFlag(r, "ragIndexEnabled", allDs, mode="walk"))
+ return {
+ "effectiveNeutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False),
+ "effectiveScope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"),
+ "effectiveRagIndexEnabled": "mixed" if len(ragVals) > 1 else (ragVals.pop() if ragVals else False),
+ }
+
+
+def _aggregatePersonalRoot(allDs: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """Aggregate effective values across all personal DS records.
+
+ Uses getEffectiveFlag in aggregate mode on each connection-root record.
+ If no root records exist, aggregates walk-effective values of all records."""
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
+ getEffectiveFlag, _AUTHORITY_SOURCE_TYPES,
+ )
+ if not allDs:
+ return {"effectiveNeutralize": False, "effectiveScope": "personal", "effectiveRagIndexEnabled": False}
+
+ rootRecords = [
+ d for d in allDs
+ if d.get("sourceType", "") in _AUTHORITY_SOURCE_TYPES
+ and _normalisePath(d.get("path", "")) == "/"
+ ]
+ targets = rootRecords if rootRecords else allDs
+
+ neutralizeVals = set()
+ scopeVals = set()
+ ragVals = set()
+ for ds in targets:
+ neutralizeVals.add(getEffectiveFlag(ds, "neutralize", allDs, mode="aggregate"))
+ scopeVals.add(getEffectiveFlag(ds, "scope", allDs, mode="aggregate"))
+ ragVals.add(getEffectiveFlag(ds, "ragIndexEnabled", allDs, mode="aggregate"))
+ return {
+ "effectiveNeutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False),
+ "effectiveScope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"),
+ "effectiveRagIndexEnabled": "mixed" if len(ragVals) > 1 else (ragVals.pop() if ragVals else False),
+ }
+
+
+def _aggregateMandateGroup(
+ mandateId: str,
+ allFds: List[Dict[str, Any]],
+ instanceId: str,
+ context: Any,
+ rootIf: Any,
+) -> Dict[str, Any]:
+ """Aggregate effective values across FDS records belonging to this mandate group.
+
+ Uses getEffectiveFlagFds in aggregate mode on each workspace-level FDS
+ (tableName="*") that belongs to the given mandateId. This correctly resolves
+ inherited values from the full FDS hierarchy."""
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds
+
+ groupFds = [f for f in allFds if f.get("mandateId") == mandateId]
+ workspaceLevelFds = [f for f in groupFds if (f.get("tableName") or "") == "*"]
+ targets = workspaceLevelFds if workspaceLevelFds else groupFds
+
+ if not targets:
+ return {"effectiveNeutralize": False, "effectiveScope": "personal", "effectiveRagIndexEnabled": False}
+
+ neutralizeVals = set()
+ scopeVals = set()
+ ragVals = set()
+ for fds in targets:
+ neutralizeVals.add(getEffectiveFlagFds(fds, "neutralize", allFds, mode="aggregate"))
+ scopeVals.add(getEffectiveFlagFds(fds, "scope", allFds, mode="aggregate"))
+ ragVals.add(getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate"))
+ return {
+ "effectiveNeutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False),
+ "effectiveScope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"),
+ "effectiveRagIndexEnabled": "mixed" if len(ragVals) > 1 else (ragVals.pop() if ragVals else False),
+ }
diff --git a/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py
index 00180c9f..64a0019c 100644
--- a/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py
+++ b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py
@@ -3,9 +3,15 @@
"""Cascade-inherit semantics for DataSource flags (neutralize, ragIndexEnabled, scope).
Three-state flags allow tree elements to either set an explicit value or
-inherit the value from their nearest ancestor in the path hierarchy. The
-walker (RAG/Neutralize) and routes resolve the *effective* value; the cascade
-helper resets explicit descendant values when a parent is toggled.
+inherit the value from their nearest ancestor in the path hierarchy.
+
+Modes:
+ - 'walk' (default): resolves the *concrete* effective value per-item
+ (never returns 'mixed'). Used by backend consumers (RAG walker,
+ neutralization pipeline, scope filter, etc.).
+ - 'aggregate': resolves the *display* effective value per-item. If the
+ item has descendants with differing walk-effective values, returns
+ 'mixed'. Used by listing endpoints and PATCH responses for the UI.
Path-traversal rules:
- A DataSource is identified by `(connectionId, sourceType, path)`.
@@ -17,11 +23,12 @@ Path-traversal rules:
"""
import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
logger = logging.getLogger(__name__)
_INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope")
+_INHERITABLE_FDS_FLAGS = ("neutralize", "ragIndexEnabled", "scope")
# Connection-root DataSources carry the authority as their sourceType
# (e.g. 'msft', 'google'). They sit one level above all service DataSources
@@ -29,6 +36,12 @@ _INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope")
# cross sourceType boundaries — but ONLY from these authority roots.
_AUTHORITY_SOURCE_TYPES = frozenset({"local", "google", "msft", "clickup", "infomaniak"})
+Mode = Literal["walk", "aggregate"]
+
+
+# ---------------------------------------------------------------------------
+# Internal helpers
+# ---------------------------------------------------------------------------
def _normalisePath(path: Optional[str]) -> str:
"""Normalize a DataSource path to '/'-prefixed, no trailing slash (except root)."""
@@ -49,10 +62,7 @@ def _flagDefault(flag: str) -> Any:
def _isExplicit(value: Any) -> bool:
- """A flag value is explicit when it is not None.
-
- Note: legacy rows may carry empty-string scope; treat as inherit too.
- """
+ """A flag value is explicit when it is not None/empty-string."""
if value is None:
return False
if isinstance(value, str) and value == "":
@@ -66,6 +76,21 @@ def _getRecordValue(rec: Any, key: str) -> Any:
return getattr(rec, key, None)
+def _isAncestorPath(ancestor: str, descendant: str) -> bool:
+ """True iff `ancestor` is a strict path-prefix of `descendant`."""
+ if ancestor == descendant:
+ return False
+ if ancestor == "/":
+ return descendant != "/"
+ return descendant.startswith(ancestor + "/")
+
+
+def _pathDepth(path: str) -> int:
+ if path == "/":
+ return 0
+ return path.count("/")
+
+
def _findAncestorChain(
rec: Dict[str, Any],
allDs: Iterable[Dict[str, Any]],
@@ -74,15 +99,13 @@ def _findAncestorChain(
ordered nearest-first.
Two ancestor relations are merged:
- 1) **same-sourceType path-ancestor** — strict path-prefix within the
- same service tree (sharepointFolder, gmailFolder, ...).
- 2) **connection-root ancestor** — a DS with `path='/'` and
- `sourceType` ∈ authority set (msft, google, ...) is the parent of
- every other DS in that connection regardless of sourceType, so a
- toggle on the connection node propagates to all services beneath.
+ 1) same-sourceType path-ancestor — strict path-prefix within the
+ same service tree.
+ 2) connection-root ancestor — a DS with `path='/'` and
+ `sourceType` in authority set is the parent of every other DS
+ in that connection regardless of sourceType.
- The connection-root is always the most distant ancestor and therefore
- sorts after any same-sourceType ancestors.
+ The connection-root is always the most distant ancestor.
"""
recPath = _normalisePath(_getRecordValue(rec, "path"))
recSourceType = _getRecordValue(rec, "sourceType")
@@ -114,36 +137,89 @@ def _findAncestorChain(
return chain
-def _isAncestorPath(ancestor: str, descendant: str) -> bool:
- """True iff `ancestor` is a strict path-prefix of `descendant`.
+def _isDescendantDs(parentRec: Dict[str, Any], candidate: Dict[str, Any]) -> bool:
+ """True iff `candidate` is a descendant of `parentRec` in the DS hierarchy."""
+ parentSourceType = _getRecordValue(parentRec, "sourceType")
+ parentPath = _normalisePath(_getRecordValue(parentRec, "path"))
+ parentConnectionId = _getRecordValue(parentRec, "connectionId")
+ parentId = _getRecordValue(parentRec, "id")
- '/' is ancestor of every non-root path. For non-root prefixes, the
- descendant must continue with '/' so '/foo' isn't treated as ancestor of
- '/foobar'.
- """
- if ancestor == descendant:
+ candId = _getRecordValue(candidate, "id")
+ if candId == parentId:
+ return False
+ if _getRecordValue(candidate, "connectionId") != parentConnectionId:
return False
- if ancestor == "/":
- return descendant != "/"
- return descendant.startswith(ancestor + "/")
+ candSourceType = _getRecordValue(candidate, "sourceType")
+ candPath = _normalisePath(_getRecordValue(candidate, "path"))
+
+ parentIsConnectionRoot = (
+ parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/"
+ )
+ if parentIsConnectionRoot:
+ return True
+ if candSourceType != parentSourceType:
+ return False
+ return _isAncestorPath(parentPath, candPath)
+
+
+# ---------------------------------------------------------------------------
+# DataSource: getEffectiveFlag
+# ---------------------------------------------------------------------------
def getEffectiveFlag(
rec: Dict[str, Any],
flag: str,
sameConnectionDs: Iterable[Dict[str, Any]],
+ mode: Mode = "walk",
) -> Any:
"""Resolve the effective value of a flag via path-traversal.
- Order: own value (if explicit) → nearest ancestor with explicit value →
- static default (`False` or `'personal'`).
+ mode='walk': own explicit → nearest ancestor explicit → default.
+ Always returns a concrete value (never 'mixed').
+ mode='aggregate': same as walk for leaf value, but if the item has
+ descendants whose walk-effective values differ from
+ each other, returns 'mixed'.
"""
if flag not in _INHERITABLE_FLAGS:
raise ValueError(f"Unknown inheritable flag: {flag}")
+
+ allDs = list(sameConnectionDs)
+
+ walkValue = _resolveWalkValue(rec, flag, allDs)
+
+ if mode == "walk":
+ return walkValue
+
+ # mode == 'aggregate': check subtree for heterogeneous effective values
+ descendants = [d for d in allDs if _isDescendantDs(rec, d)]
+ if not descendants:
+ return walkValue
+
+ subtreeValues = set()
+ subtreeValues.add(_normaliseForComparison(walkValue))
+ for desc in descendants:
+ descEffective = _resolveWalkValue(desc, flag, allDs)
+ subtreeValues.add(_normaliseForComparison(descEffective))
+ if len(subtreeValues) > 1:
+ recId = _getRecordValue(rec, "id")
+ descId = _getRecordValue(desc, "id")
+ descOwnVal = _getRecordValue(desc, flag)
+ logger.info(
+ "DS aggregate MIXED for rec=%s flag=%s: walkValue=%s, "
+ "divergent desc=%s (own=%s, effective=%s), subtreeValues=%s",
+ recId, flag, walkValue, descId, descOwnVal, descEffective, subtreeValues,
+ )
+ return "mixed"
+ return walkValue
+
+
+def _resolveWalkValue(rec: Dict[str, Any], flag: str, allDs: List[Dict[str, Any]]) -> Any:
+ """Core walk resolution: own explicit → ancestor chain → default."""
own = _getRecordValue(rec, flag)
if _isExplicit(own):
return own
- chain = _findAncestorChain(rec, sameConnectionDs)
+ chain = _findAncestorChain(rec, allDs)
for ancestor in chain:
ancestorVal = _getRecordValue(ancestor, flag)
if _isExplicit(ancestorVal):
@@ -151,69 +227,112 @@ def getEffectiveFlag(
return _flagDefault(flag)
+def _normaliseForComparison(value: Any) -> Any:
+ """Normalize values for set-comparison (bool as int to avoid hash issues)."""
+ if isinstance(value, bool):
+ return int(value)
+ return value
+
+
+# ---------------------------------------------------------------------------
+# DataSource: cascadeResetDescendants (bottom-up)
+# ---------------------------------------------------------------------------
+
def cascadeResetDescendants(
rootIf: Any,
parentRec: Dict[str, Any],
flag: str,
-) -> int:
+) -> List[str]:
"""Reset all explicit descendant values of `flag` to NULL (= inherit).
- Descendant relation mirrors `_findAncestorChain`:
- - Connection-root (`path='/'` AND `sourceType` ∈ authorities) is parent
- of every other DS in that connection (cross-sourceType cascade).
- - Otherwise: same-sourceType strict path-descendants only.
+ Reset order: bottom-up (deepest first) for crash safety.
+ The parent itself is NOT modified here — the caller sets the master value
+ after this function returns.
- Only the targeted `flag` is reset; other flags on the descendant are
- untouched.
-
- Returns the number of records updated.
+ Returns list of reset record IDs in bottom-up order.
"""
if flag not in _INHERITABLE_FLAGS:
raise ValueError(f"Unknown inheritable flag: {flag}")
from modules.datamodels.datamodelDataSource import DataSource
connectionId = _getRecordValue(parentRec, "connectionId")
- parentSourceType = _getRecordValue(parentRec, "sourceType")
- parentPath = _normalisePath(_getRecordValue(parentRec, "path"))
parentId = _getRecordValue(parentRec, "id")
- if not connectionId or not parentSourceType:
- return 0
-
- parentIsConnectionRoot = (
- parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/"
- )
+ if not connectionId:
+ return []
siblings = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
- affected = 0
+
+ toReset: List[Tuple[int, str]] = []
for sib in siblings:
- sibId = _getRecordValue(sib, "id")
- if sibId == parentId:
+ if not _isDescendantDs(parentRec, sib):
continue
- sibSourceType = _getRecordValue(sib, "sourceType")
- sibPath = _normalisePath(_getRecordValue(sib, "path"))
- if parentIsConnectionRoot:
- # Connection-root resets everything else under this connection.
- pass
- else:
- if sibSourceType != parentSourceType:
- continue
- if not _isAncestorPath(parentPath, sibPath):
- continue
sibVal = _getRecordValue(sib, flag)
if not _isExplicit(sibVal):
continue
+ sibId = _getRecordValue(sib, "id")
+ sibPath = _normalisePath(_getRecordValue(sib, "path"))
+ toReset.append((_pathDepth(sibPath), sibId))
+
+ # Sort deepest first (bottom-up)
+ toReset.sort(key=lambda x: x[0], reverse=True)
+
+ resetIds: List[str] = []
+ for _, sibId in toReset:
try:
rootIf.db.recordModify(DataSource, sibId, {flag: None})
- affected += 1
+ resetIds.append(sibId)
except Exception as exc:
logger.warning("Cascade-reset failed for DataSource %s flag=%s: %s", sibId, flag, exc)
- if affected:
- logger.info(
- "Cascade-reset %s on %d descendants of DataSource (connectionId=%s, sourceType=%s, path=%s, connectionRoot=%s)",
- flag, affected, connectionId, parentSourceType, parentPath, parentIsConnectionRoot,
- )
- return affected
+ if resetIds:
+ logger.info(
+ "Cascade-reset %s on %d descendants of DataSource %s (bottom-up)",
+ flag, len(resetIds), parentId,
+ )
+ return resetIds
+
+
+# ---------------------------------------------------------------------------
+# DataSource: collectAncestorChain (for updatedAncestors in PATCH response)
+# ---------------------------------------------------------------------------
+
+def collectAncestorChain(
+ rec: Dict[str, Any],
+ sameConnectionDs: Iterable[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """Return ancestor chain of `rec` (nearest-first), same as internal helper.
+
+ Exposed for PATCH endpoints to compute updatedAncestors.
+ """
+ return _findAncestorChain(rec, sameConnectionDs)
+
+
+# ---------------------------------------------------------------------------
+# DataSource: buildEffectiveByConnection
+# ---------------------------------------------------------------------------
+
+def buildEffectiveByConnection(
+ dataSources: Iterable[Dict[str, Any]],
+ flag: str,
+ mode: Mode = "walk",
+) -> Dict[str, Any]:
+ """Pre-compute the effective value of `flag` for every DataSource id.
+
+ Uses the specified mode. O(N^2) worst case but N is bounded per connection.
+ """
+ if flag not in _INHERITABLE_FLAGS:
+ raise ValueError(f"Unknown inheritable flag: {flag}")
+ allDs = list(dataSources)
+ out: Dict[str, Any] = {}
+ for rec in allDs:
+ recId = _getRecordValue(rec, "id")
+ out[recId] = getEffectiveFlag(rec, flag, allDs, mode=mode)
+ return out
+
+
+# ---------------------------------------------------------------------------
+# FeatureDataSource helpers
+# ---------------------------------------------------------------------------
def _fdsClassify(fds: Dict[str, Any]) -> str:
"""Return 'workspace' | 'table' | 'record' based on the FDS identifier shape."""
@@ -229,14 +348,14 @@ def _fdsClassify(fds: Dict[str, Any]) -> str:
def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool:
"""Return True iff `parent` FDS is a strict ancestor of `child` FDS.
- Hierarchy within one `workspaceInstanceId`:
- workspace-wildcard (tableName='*') → table-wildcard (tableName='X', !recordFilter)
- → record-fds (tableName='X', recordFilter.id=...)
- table-wildcard (tableName='X') → record-fds (tableName='X', recordFilter.id=...)
+ Hierarchy within one featureInstanceId (allFds is already scoped to
+ a single workspace):
+ feature-wildcard (tableName='*') -> table-wildcard / record-fds
+ table-wildcard (tableName='X') -> record-fds (tableName='X')
"""
- parentWsId = _getRecordValue(parent, "workspaceInstanceId")
- childWsId = _getRecordValue(child, "workspaceInstanceId")
- if not parentWsId or parentWsId != childWsId:
+ parentFiId = _getRecordValue(parent, "featureInstanceId")
+ childFiId = _getRecordValue(child, "featureInstanceId")
+ if not parentFiId or parentFiId != childFiId:
return False
if _getRecordValue(parent, "id") == _getRecordValue(child, "id"):
return False
@@ -251,23 +370,68 @@ def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool:
return False
+def _fdsDepth(fds: Dict[str, Any]) -> int:
+ kind = _fdsClassify(fds)
+ if kind == "workspace":
+ return 0
+ if kind == "table":
+ return 1
+ return 2
+
+
+# ---------------------------------------------------------------------------
+# FeatureDataSource: getEffectiveFlagFds
+# ---------------------------------------------------------------------------
+
def getEffectiveFlagFds(
rec: Dict[str, Any],
flag: str,
sameWorkspaceFds: Iterable[Dict[str, Any]],
+ mode: Mode = "walk",
) -> Any:
"""Resolve effective value of a FeatureDataSource flag.
- Order: own (if explicit) → table-wildcard (if explicit) →
- workspace-wildcard (if explicit) → static default.
+ mode='walk': own explicit -> table-wildcard -> workspace-wildcard -> default.
+ mode='aggregate': same but returns 'mixed' if descendants diverge.
"""
- if flag not in ("neutralize", "scope"):
+ if flag not in _INHERITABLE_FDS_FLAGS:
raise ValueError(f"Unknown inheritable FDS flag: {flag}")
+
+ allFds = list(sameWorkspaceFds)
+ walkValue = _resolveWalkValueFds(rec, flag, allFds)
+
+ if mode == "walk":
+ return walkValue
+
+ # mode == 'aggregate'
+ descendants = [f for f in allFds if _fdsIsAncestor(rec, f)]
+ if not descendants:
+ return walkValue
+
+ subtreeValues = set()
+ subtreeValues.add(_normaliseForComparison(walkValue))
+ for desc in descendants:
+ descEffective = _resolveWalkValueFds(desc, flag, allFds)
+ subtreeValues.add(_normaliseForComparison(descEffective))
+ if len(subtreeValues) > 1:
+ recId = _getRecordValue(rec, "id")
+ descId = _getRecordValue(desc, "id")
+ descOwnVal = _getRecordValue(desc, flag)
+ logger.info(
+ "FDS aggregate MIXED for rec=%s flag=%s: walkValue=%s, "
+ "divergent desc=%s (own=%s, effective=%s), subtreeValues=%s",
+ recId, flag, walkValue, descId, descOwnVal, descEffective, subtreeValues,
+ )
+ return "mixed"
+ return walkValue
+
+
+def _resolveWalkValueFds(rec: Dict[str, Any], flag: str, allFds: List[Dict[str, Any]]) -> Any:
+ """Core walk resolution for FDS."""
own = _getRecordValue(rec, flag)
if _isExplicit(own):
return own
- workspaceFds: List[Dict[str, Any]] = list(sameWorkspaceFds)
- ancestors = [a for a in workspaceFds if _fdsIsAncestor(a, rec)]
+ ancestors = [a for a in allFds if _fdsIsAncestor(a, rec)]
ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1)
for ancestor in ancestors:
val = _getRecordValue(ancestor, flag)
@@ -276,27 +440,32 @@ def getEffectiveFlagFds(
return _flagDefault(flag)
+# ---------------------------------------------------------------------------
+# FeatureDataSource: cascadeResetDescendantsFds (bottom-up)
+# ---------------------------------------------------------------------------
+
def cascadeResetDescendantsFds(
rootIf: Any,
parentRec: Dict[str, Any],
flag: str,
-) -> int:
+) -> List[str]:
"""Reset explicit `flag` to NULL on every descendant FDS of `parentRec`.
- Only the targeted flag is reset; other flags on descendants are untouched.
- Returns the number of records updated.
+ Reset order: bottom-up (deepest first) for crash safety.
+ Returns list of reset record IDs in bottom-up order.
"""
- if flag not in ("neutralize", "scope"):
+ if flag not in _INHERITABLE_FDS_FLAGS:
raise ValueError(f"Unknown inheritable FDS flag: {flag}")
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
workspaceInstanceId = _getRecordValue(parentRec, "workspaceInstanceId")
if not workspaceInstanceId:
- return 0
+ return []
siblings = rootIf.db.getRecordset(
FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId}
)
- affected = 0
+
+ toReset: List[Tuple[int, str]] = []
for sib in siblings:
if not _fdsIsAncestor(parentRec, sib):
continue
@@ -304,39 +473,159 @@ def cascadeResetDescendantsFds(
if not _isExplicit(sibVal):
continue
sibId = _getRecordValue(sib, "id")
+ toReset.append((_fdsDepth(sib), sibId))
+
+ # Sort deepest first (bottom-up)
+ toReset.sort(key=lambda x: x[0], reverse=True)
+
+ resetIds: List[str] = []
+ for _, sibId in toReset:
try:
rootIf.db.recordModify(FeatureDataSource, sibId, {flag: None})
- affected += 1
+ resetIds.append(sibId)
except Exception as exc:
logger.warning("FDS cascade-reset failed for %s flag=%s: %s", sibId, flag, exc)
- if affected:
+
+ if resetIds:
logger.info(
- "FDS cascade-reset %s on %d descendants of FDS (workspaceInstanceId=%s, kind=%s)",
- flag, affected, workspaceInstanceId, _fdsClassify(parentRec),
+ "FDS cascade-reset %s on %d descendants of FDS %s (bottom-up)",
+ flag, len(resetIds), _getRecordValue(parentRec, "id"),
)
- return affected
+ return resetIds
-def buildEffectiveByConnection(
- dataSources: Iterable[Dict[str, Any]],
- flag: str,
-) -> Dict[str, Any]:
- """Pre-compute the effective value of `flag` for every DataSource id.
+# ---------------------------------------------------------------------------
+# FeatureDataSource: collectAncestorChainFds
+# ---------------------------------------------------------------------------
- Useful for batch operations (walker, route DTOs) that touch many records
- at once. O(N²) in the worst case but N is bounded per connection.
+def collectAncestorChainFds(
+ rec: Dict[str, Any],
+ sameWorkspaceFds: Iterable[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """Return ancestor chain of `rec` FDS (nearest-first).
+
+ Exposed for PATCH endpoints to compute updatedAncestors.
"""
- if flag not in _INHERITABLE_FLAGS:
- raise ValueError(f"Unknown inheritable flag: {flag}")
- bySourceType: Dict[Tuple[str, str], List[Dict[str, Any]]] = {}
- for ds in dataSources:
- connId = _getRecordValue(ds, "connectionId") or ""
- srcType = _getRecordValue(ds, "sourceType") or ""
- bySourceType.setdefault((connId, srcType), []).append(ds)
+ allFds = list(sameWorkspaceFds)
+ ancestors = [a for a in allFds if _fdsIsAncestor(a, rec)]
+ ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1)
+ return ancestors
+
+# ---------------------------------------------------------------------------
+# FeatureDataSource: buildEffectiveByWorkspaceFds
+# ---------------------------------------------------------------------------
+
+def buildEffectiveByWorkspaceFds(
+ fdses: Iterable[Dict[str, Any]],
+ flag: str,
+ mode: Mode = "walk",
+) -> Dict[str, Any]:
+ """Pre-compute the effective value of `flag` for every FDS id."""
+ if flag not in _INHERITABLE_FDS_FLAGS:
+ raise ValueError(f"Unknown inheritable FDS flag: {flag}")
+ allFds = list(fdses)
out: Dict[str, Any] = {}
- for group in bySourceType.values():
- for rec in group:
- recId = _getRecordValue(rec, "id")
- out[recId] = getEffectiveFlag(rec, flag, group)
+ for rec in allFds:
+ recId = _getRecordValue(rec, "id")
+ out[recId] = getEffectiveFlagFds(rec, flag, allFds, mode=mode)
return out
+
+
+# ---------------------------------------------------------------------------
+# Bulk resolve: effective flags for arbitrary paths (even without DB record)
+# ---------------------------------------------------------------------------
+
+def resolveEffectiveForPath(
+ connectionId: str,
+ sourceType: str,
+ path: str,
+ allDs: List[Dict[str, Any]],
+ mode: Mode = "aggregate",
+) -> Dict[str, Any]:
+ """Resolve effective flags for ANY (connectionId, sourceType, path) tuple.
+
+ Works whether or not a DataSource record exists for this exact path.
+ Returns dict with effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled.
+ """
+ normPath = _normalisePath(path)
+ exactRecord = None
+ for ds in allDs:
+ if (
+ _getRecordValue(ds, "connectionId") == connectionId
+ and _getRecordValue(ds, "sourceType") == sourceType
+ and _normalisePath(_getRecordValue(ds, "path")) == normPath
+ ):
+ exactRecord = ds
+ break
+
+ if exactRecord:
+ return {
+ "effectiveNeutralize": getEffectiveFlag(exactRecord, "neutralize", allDs, mode=mode),
+ "effectiveScope": getEffectiveFlag(exactRecord, "scope", allDs, mode=mode),
+ "effectiveRagIndexEnabled": getEffectiveFlag(exactRecord, "ragIndexEnabled", allDs, mode=mode),
+ }
+
+ virtualRec = {
+ "id": "__virtual__",
+ "connectionId": connectionId,
+ "sourceType": sourceType,
+ "path": normPath,
+ "neutralize": None,
+ "scope": None,
+ "ragIndexEnabled": None,
+ }
+ return {
+ "effectiveNeutralize": _resolveWalkValue(virtualRec, "neutralize", allDs),
+ "effectiveScope": _resolveWalkValue(virtualRec, "scope", allDs),
+ "effectiveRagIndexEnabled": _resolveWalkValue(virtualRec, "ragIndexEnabled", allDs),
+ }
+
+
+def resolveEffectiveForFds(
+ featureInstanceId: str,
+ tableName: str,
+ recordFilter: Optional[Dict[str, str]],
+ allFds: List[Dict[str, Any]],
+ mode: Mode = "aggregate",
+) -> Dict[str, Any]:
+ """Resolve effective flags for ANY FDS tuple (even without DB record).
+
+ `allFds` is pre-scoped to a single workspace (loaded with
+ workspaceInstanceId filter). Within that set, the coordinate is
+ featureInstanceId + tableName + recordFilter.
+
+ Returns dict with effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled.
+ """
+ exactRecord = None
+ for fds in allFds:
+ if _getRecordValue(fds, "featureInstanceId") != featureInstanceId:
+ continue
+ if (_getRecordValue(fds, "tableName") or "") != tableName:
+ continue
+ fdsFilter = _getRecordValue(fds, "recordFilter")
+ if fdsFilter == recordFilter:
+ exactRecord = fds
+ break
+
+ if exactRecord:
+ return {
+ "effectiveNeutralize": getEffectiveFlagFds(exactRecord, "neutralize", allFds, mode=mode),
+ "effectiveScope": getEffectiveFlagFds(exactRecord, "scope", allFds, mode=mode),
+ "effectiveRagIndexEnabled": getEffectiveFlagFds(exactRecord, "ragIndexEnabled", allFds, mode=mode),
+ }
+
+ virtualRec = {
+ "id": "__virtual__",
+ "featureInstanceId": featureInstanceId,
+ "tableName": tableName,
+ "recordFilter": recordFilter,
+ "neutralize": None,
+ "scope": None,
+ "ragIndexEnabled": None,
+ }
+ return {
+ "effectiveNeutralize": _resolveWalkValueFds(virtualRec, "neutralize", allFds),
+ "effectiveScope": _resolveWalkValueFds(virtualRec, "scope", allFds),
+ "effectiveRagIndexEnabled": _resolveWalkValueFds(virtualRec, "ragIndexEnabled", allFds),
+ }
diff --git a/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py b/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py
index 6698e164..01c585d8 100644
--- a/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py
+++ b/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py
@@ -147,7 +147,7 @@ class KnowledgeService:
else getattr(existing, "status", "")
) or ""
if existingMeta.get("hash") == contentHash and existingStatus == "indexed":
- logger.info(
+ logger.debug(
"ingestion.skipped.duplicate sourceKind=%s sourceId=%s hash=%s",
job.sourceKind, job.sourceId, contentHash[:12],
extra={
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py
index 618a9965..be059eef 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py
@@ -431,6 +431,15 @@ def registerKnowledgeIngestionConsumer() -> None:
callbackRegistry.register("connection.established", _onConnectionEstablished)
callbackRegistry.register("connection.revoked", _onConnectionRevoked)
registerJobHandler(BOOTSTRAP_JOB_TYPE, _bootstrapJobHandler)
+
+ from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import (
+ FEATURE_BOOTSTRAP_JOB_TYPE, _featureBootstrapHandler,
+ )
+ registerJobHandler(FEATURE_BOOTSTRAP_JOB_TYPE, _featureBootstrapHandler)
+
registerDailyResyncScheduler()
_registered = True
- logger.info("KnowledgeIngestionConsumer registered (established/revoked + %s handler + daily resync)", BOOTSTRAP_JOB_TYPE)
+ logger.info(
+ "KnowledgeIngestionConsumer registered (established/revoked + %s + %s handler + daily resync)",
+ BOOTSTRAP_JOB_TYPE, FEATURE_BOOTSTRAP_JOB_TYPE,
+ )
diff --git a/modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py b/modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py
new file mode 100644
index 00000000..aa81d929
--- /dev/null
+++ b/modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py
@@ -0,0 +1,289 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""Feature-data RAG bootstrap: indexes FeatureDataSource rows into the knowledge store.
+
+Analogous to connection.bootstrap for external connections (Google, Microsoft),
+this handler reads FeatureDataSource records with ragIndexEnabled=True, queries
+the underlying feature tables via FeatureDataProvider, serialises each row into
+text, and feeds it through KnowledgeService.requestIngestion so the data
+appears in ContentChunk embeddings for semantic RAG search.
+
+Job type: ``feature.bootstrap``
+Payload: ``{"workspaceInstanceId": "...", "featureDataSourceIds": [...] (optional)}``
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from typing import Any, Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+FEATURE_BOOTSTRAP_JOB_TYPE = "feature.bootstrap"
+
+
+def _loadRagEnabledFds(workspaceInstanceId: str, featureDataSourceIds: Optional[List[str]] = None):
+ """Load FeatureDataSource rows whose effective ragIndexEnabled is True.
+
+ Returns dicts with resolved flags so downstream code can read them directly.
+ """
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
+ from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds
+
+ rootIf = getRootInterface()
+ allFds = rootIf.db.getRecordset(
+ FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId}
+ )
+ resolved = []
+ for fds in allFds:
+ tblName = (fds.get("tableName") if isinstance(fds, dict) else getattr(fds, "tableName", "")) or ""
+ fCode = (fds.get("featureCode") if isinstance(fds, dict) else getattr(fds, "featureCode", "")) or ""
+ if tblName == "*" or not tblName or not fCode:
+ continue
+ effRag = getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate")
+ if effRag is not True:
+ continue
+ row = dict(fds) if isinstance(fds, dict) else {**fds.__dict__}
+ row["_effectiveNeutralize"] = getEffectiveFlagFds(fds, "neutralize", allFds, mode="aggregate")
+ row["_effectiveScope"] = getEffectiveFlagFds(fds, "scope", allFds, mode="aggregate") or "featureInstance"
+ row["ragIndexEnabled"] = True
+ resolved.append(row)
+
+ if featureDataSourceIds:
+ idSet = set(featureDataSourceIds)
+ resolved = [r for r in resolved if r.get("id") in idSet]
+ return resolved
+
+
+def _serializeRowToText(row: Dict[str, Any], neutralizeFields: Optional[List[str]] = None) -> str:
+ """Convert a feature-table row into readable text for embedding.
+
+ Skips internal fields (starting with '_' or 'sys') and produces
+ ``key: value`` lines that embed well semantically.
+ """
+ neutralizeSet = set(neutralizeFields or [])
+ lines = []
+ for key, value in row.items():
+ if key.startswith("_") or key.startswith("sys"):
+ continue
+ if key == "id":
+ continue
+ if value is None or value == "" or value == []:
+ continue
+ if key in neutralizeSet:
+ value = "[REDACTED]"
+ elif isinstance(value, (dict, list)):
+ value = json.dumps(value, ensure_ascii=False, default=str)
+ else:
+ value = str(value)
+ lines.append(f"{key}: {value}")
+ return "\n".join(lines)
+
+
+def _getFeatureDbConnector(featureCode: str):
+ """Create a lightweight DB connector to the feature database."""
+ from modules.connectors.connectorDbPostgre import DatabaseConnector
+ from modules.shared.configuration import APP_CONFIG
+
+ dbName = f"poweron_{featureCode.lower()}"
+ return DatabaseConnector(
+ dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
+ dbDatabase=dbName,
+ dbUser=APP_CONFIG.get("DB_USER"),
+ dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
+ dbPort=int(APP_CONFIG.get("DB_PORT", 5432)),
+ userId="system.feature_bootstrap",
+ )
+
+
+async def _featureBootstrapHandler(
+ job: Dict[str, Any],
+ progressCb,
+) -> Dict[str, Any]:
+ """Walk RAG-enabled FeatureDataSources and index their rows."""
+ payload = job.get("payload") or {}
+ workspaceInstanceId = payload.get("workspaceInstanceId")
+ featureDataSourceIds = payload.get("featureDataSourceIds")
+ if not workspaceInstanceId:
+ raise ValueError("feature.bootstrap requires payload.workspaceInstanceId")
+
+ progressCb(5, messageKey="Feature-Datenquellen werden geladen...")
+
+ fdsList = _loadRagEnabledFds(workspaceInstanceId, featureDataSourceIds)
+ if not fdsList:
+ logger.info(
+ "feature.bootstrap.skipped — no rag-enabled FDS for workspace %s",
+ workspaceInstanceId,
+ )
+ return {"workspaceInstanceId": workspaceInstanceId, "skipped": True, "reason": "no_rag_enabled_fds"}
+
+ from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
+ from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob
+ from modules.serviceCenter.context import ServiceCenterContext
+ from modules.serviceCenter import getService
+ from modules.security.rootAccess import getRootUser
+
+ totalIndexed = 0
+ totalSkipped = 0
+ totalFailed = 0
+ fdsResults = []
+
+ for fdsIdx, fds in enumerate(fdsList):
+ fdsId = fds.get("id", "")
+ featureCode = fds.get("featureCode", "")
+ tableName = fds.get("tableName", "")
+ featureInstanceId = fds.get("featureInstanceId", "")
+ mandateId = fds.get("mandateId", "")
+ neutralizeFields = fds.get("neutralizeFields") or []
+ recordFilter = fds.get("recordFilter") or {}
+ effectiveScope = fds.get("_effectiveScope", "featureInstance")
+ effectiveNeutralize = bool(fds.get("_effectiveNeutralize", False))
+
+ progressPct = 5 + int(90 * fdsIdx / len(fdsList))
+ progressCb(
+ progressPct,
+ messageKey="Indexiere {table} ({n}/{total})...",
+ messageParams={"table": tableName, "n": fdsIdx + 1, "total": len(fdsList)},
+ )
+
+ if not featureCode or not tableName or not featureInstanceId:
+ logger.warning("feature.bootstrap: skipping FDS %s — missing featureCode/tableName/fiId", fdsId)
+ continue
+
+ try:
+ dbConnector = _getFeatureDbConnector(featureCode)
+ provider = FeatureDataProvider(dbConnector)
+
+ rootUser = getRootUser()
+ ctx = ServiceCenterContext(
+ user=rootUser,
+ mandate_id=mandateId,
+ feature_instance_id=workspaceInstanceId,
+ )
+ knowledgeService = getService("knowledge", ctx)
+
+ extraFilters = [
+ {"field": k, "op": "=", "value": v}
+ for k, v in recordFilter.items()
+ ] if recordFilter else None
+
+ batchSize = 200
+ offset = 0
+ fdsIndexed = 0
+ fdsSkipped = 0
+ fdsFailed = 0
+
+ while True:
+ result = provider.browseTable(
+ tableName=tableName,
+ featureInstanceId=featureInstanceId,
+ mandateId=mandateId,
+ limit=batchSize,
+ offset=offset,
+ extraFilters=extraFilters,
+ )
+ rows = result.get("rows", [])
+ if not rows:
+ break
+
+ for row in rows:
+ rowId = row.get("id", "")
+ if not rowId:
+ continue
+
+ textContent = _serializeRowToText(row, neutralizeFields if effectiveNeutralize else None)
+ if not textContent.strip():
+ fdsSkipped += 1
+ continue
+
+ contentVersion = str(row.get("sysUpdatedAt") or row.get("sysCreatedAt") or "")
+
+ ingestionJob = IngestionJob(
+ sourceKind="feature_record",
+ sourceId=f"{workspaceInstanceId}:{tableName}:{rowId}",
+ fileName=f"{tableName}-{rowId}",
+ mimeType="application/vnd.poweron.feature-record+json",
+ userId=fds.get("userId") or "system",
+ featureInstanceId=workspaceInstanceId,
+ mandateId=mandateId,
+ contentObjects=[{
+ "contentType": "text",
+ "data": textContent,
+ "contextRef": {
+ "table": tableName,
+ "featureCode": featureCode,
+ "featureInstanceId": featureInstanceId,
+ "rowId": rowId,
+ },
+ "contentObjectId": f"{tableName}:{rowId}",
+ }],
+ structure={"sourceTable": tableName, "featureCode": featureCode},
+ contentVersion=contentVersion,
+ provenance={
+ "featureDataSourceId": fdsId,
+ "tableName": tableName,
+ "featureCode": featureCode,
+ "featureInstanceId": featureInstanceId,
+ },
+ neutralize=effectiveNeutralize,
+ )
+
+ try:
+ handle = await knowledgeService.requestIngestion(ingestionJob)
+ if handle.status == "failed":
+ fdsFailed += 1
+ logger.warning(
+ "feature.bootstrap: ingestion failed fds=%s table=%s row=%s error=%s",
+ fdsId, tableName, rowId, handle.error,
+ )
+ elif handle.status == "duplicate":
+ fdsSkipped += 1
+ else:
+ fdsIndexed += 1
+ except Exception as ingErr:
+ fdsFailed += 1
+ logger.error(
+ "feature.bootstrap: ingestion error fds=%s row=%s: %s",
+ fdsId, rowId, ingErr,
+ )
+
+ offset += batchSize
+ if len(rows) < batchSize:
+ break
+
+ totalIndexed += fdsIndexed
+ totalSkipped += fdsSkipped
+ totalFailed += fdsFailed
+
+ fdsResults.append({
+ "featureDataSourceId": fdsId,
+ "tableName": tableName,
+ "featureCode": featureCode,
+ "indexed": fdsIndexed,
+ "skippedDuplicate": fdsSkipped,
+ "failed": fdsFailed,
+ })
+
+ except Exception as fdsErr:
+ logger.error(
+ "feature.bootstrap: error processing FDS %s (%s.%s): %s",
+ fdsId, featureCode, tableName, fdsErr, exc_info=True,
+ )
+ fdsResults.append({
+ "featureDataSourceId": fdsId,
+ "tableName": tableName,
+ "featureCode": featureCode,
+ "error": str(fdsErr),
+ })
+
+ progressCb(100, messageKey="Feature-Daten-Sync abgeschlossen.")
+
+ return {
+ "workspaceInstanceId": workspaceInstanceId,
+ "indexed": totalIndexed,
+ "skippedDuplicate": totalSkipped,
+ "failed": totalFailed,
+ "dataSources": fdsResults,
+ }
diff --git a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py b/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py
deleted file mode 100644
index 0deae777..00000000
--- a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright (c) 2025 Patrick Motsch
-# All rights reserved.
-"""DEPRECATED: Use `_inheritFlags.getEffectiveFlag()` directly.
-
-Thin shim to the new cascade-inherit helper. Kept so external callers don't
-break on import — internal walkers consume pre-resolved dicts via
-`_loadRagEnabledDataSources`.
-"""
-
-from __future__ import annotations
-
-from typing import Any, Dict, List
-
-from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
-
-
-def resolveEffectiveNeutralize(
- ds: Dict[str, Any],
- allDataSources: List[Dict[str, Any]],
-) -> bool:
- """DEPRECATED: use `getEffectiveFlag(ds, 'neutralize', allDataSources)`."""
- value = getEffectiveFlag(ds, "neutralize", allDataSources)
- return bool(value)
-
-
-def resolveEffectiveRagIndexEnabled(
- ds: Dict[str, Any],
- allDataSources: List[Dict[str, Any]],
-) -> bool:
- """DEPRECATED: use `getEffectiveFlag(ds, 'ragIndexEnabled', allDataSources)`."""
- value = getEffectiveFlag(ds, "ragIndexEnabled", allDataSources)
- return bool(value)
diff --git a/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py b/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py
index 8e65fd0f..41d9d458 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py
@@ -15,8 +15,9 @@ up with "Job stuck at 10% for 10h" zombies.
These helpers wrap each phase in `asyncio.wait_for`. Sync extraction runs
on a worker thread so the loop stays responsive. Every wrapped call also
-emits a short start/done log line, so when something hangs we know the
-exact item that caused it (path, size, mime).
+emits start/done log lines at DEBUG so normal INFO logs stay quiet; for
+stuck-job triage, enable DEBUG for this module — the last
+``walker.item.start`` before a hang still pinpoints the item (path, size, mime).
"""
from __future__ import annotations
@@ -48,7 +49,7 @@ async def downloadWithTimeout(
used in log messages so we can pinpoint the offending item in case of a
hang or timeout.
"""
- logger.info("walker.download.start %s timeout=%ds", label, timeoutSeconds)
+ logger.debug("walker.download.start %s timeout=%ds", label, timeoutSeconds)
try:
result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds)
logger.debug("walker.download.done %s", label)
@@ -71,7 +72,7 @@ async def extractWithTimeout(
keep running until the process exits — but at least the walker proceeds
to the next item instead of freezing forever.
"""
- logger.info("walker.extract.start %s timeout=%ds", label, timeoutSeconds)
+ logger.debug("walker.extract.start %s timeout=%ds", label, timeoutSeconds)
try:
result = await asyncio.wait_for(
asyncio.to_thread(syncFn, *args),
@@ -102,15 +103,15 @@ async def ingestWithTimeout(
def logItemStart(service: str, label: str, *, sizeBytes: Optional[int] = None, mime: Optional[str] = None) -> None:
- """Log that processing of one item is about to begin.
+ """Log that processing of one item is about to begin (DEBUG).
When the worker hangs, the LAST `walker.item.start` line in the log
- points to the exact item that caused the freeze. This is the single
- most valuable diagnostic for stuck-job triage.
+ points to the exact item that caused the freeze. Enable DEBUG for this
+ module during triage.
"""
parts = [f"walker.item.start service={service} path={label}"]
if sizeBytes is not None:
parts.append(f"size={sizeBytes}")
if mime:
parts.append(f"mime={mime}")
- logger.info(" ".join(parts))
+ logger.debug(" ".join(parts))
diff --git a/scripts/script_migrate_user_uid.py b/scripts/script_migrate_user_uid.py
new file mode 100644
index 00000000..07f9b443
--- /dev/null
+++ b/scripts/script_migrate_user_uid.py
@@ -0,0 +1,274 @@
+#!/usr/bin/env python3
+"""One-time migration: Reassign all DB references from an old user UID to a new UID.
+
+When a user is re-created in PORTA (same username, new UUID), all existing records
+still reference the old UUID. This script scans ALL registered databases and tables
+for VARCHAR columns containing the old UID and updates them to the new UID.
+
+Affected columns include:
+ - sysCreatedBy / sysModifiedBy (on every table via PowerOnModel)
+ - userId, revokedBy, createdByUserId, publishedBy, triggeredBy, assignedTo, etc.
+
+The script auto-detects the new UID from the UserInDB table by username.
+
+Usage:
+ # Dry-run (default) — shows what would change, no writes:
+ python scripts/script_migrate_user_uid.py --username patrick.helvetia --old-uid
+
+ # Execute for real:
+ python scripts/script_migrate_user_uid.py --username patrick.helvetia --old-uid --execute
+"""
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+scriptPath = Path(__file__).resolve()
+gatewayPath = scriptPath.parent.parent
+sys.path.insert(0, str(gatewayPath))
+os.chdir(str(gatewayPath))
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
+logger = logging.getLogger(__name__)
+
+import psycopg2
+import psycopg2.extras
+from modules.shared.configuration import APP_CONFIG
+
+
+ALL_DATABASES = [
+ "poweron_app",
+ "poweron_chat",
+ "poweron_management",
+ "poweron_knowledge",
+ "poweron_billing",
+ "poweron_workspace",
+ "poweron_graphicaleditor",
+ "poweron_chatbot",
+ "poweron_trustee",
+ "poweron_commcoach",
+ "poweron_neutralization",
+ "poweron_realestate",
+ "poweron_teamsbot",
+]
+
+
+def _getConnection(dbName: str):
+ return psycopg2.connect(
+ host=APP_CONFIG.get("DB_HOST", "localhost"),
+ port=int(APP_CONFIG.get("DB_PORT", "5432")),
+ database=dbName,
+ user=APP_CONFIG.get("DB_USER"),
+ password=APP_CONFIG.get("DB_PASSWORD_SECRET"),
+ client_encoding="utf8",
+ )
+
+
+def _getTablesInDb(conn) -> List[str]:
+ with conn.cursor() as cur:
+ cur.execute("""
+ SELECT table_name FROM information_schema.tables
+ WHERE table_schema = 'public'
+ AND table_type = 'BASE TABLE'
+ AND table_name NOT LIKE '\\_%%'
+ ORDER BY table_name
+ """)
+ return [row[0] for row in cur.fetchall()]
+
+
+def _getVarcharColumns(conn, tableName: str) -> List[str]:
+ """Get all VARCHAR/TEXT columns for a table (potential user-ID carriers)."""
+ with conn.cursor() as cur:
+ cur.execute("""
+ SELECT column_name FROM information_schema.columns
+ WHERE table_schema = 'public'
+ AND table_name = %s
+ AND data_type IN ('character varying', 'text')
+ ORDER BY ordinal_position
+ """, (tableName,))
+ return [row[0] for row in cur.fetchall()]
+
+
+def _countMatches(conn, tableName: str, columnName: str, oldUid: str) -> int:
+ with conn.cursor() as cur:
+ cur.execute(
+ f'SELECT COUNT(*) FROM "{tableName}" WHERE "{columnName}" = %s',
+ (oldUid,),
+ )
+ return cur.fetchone()[0]
+
+
+def _updateColumn(conn, tableName: str, columnName: str, oldUid: str, newUid: str) -> int:
+ with conn.cursor() as cur:
+ cur.execute(
+ f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE "{columnName}" = %s',
+ (newUid, oldUid),
+ )
+ return cur.rowcount
+
+
+def _lookupNewUid(username: str) -> Optional[str]:
+ """Find the current UID for a username in poweron_app.UserInDB."""
+ conn = _getConnection("poweron_app")
+ try:
+ with conn.cursor() as cur:
+ cur.execute(
+ 'SELECT "id" FROM "UserInDB" WHERE "username" = %s',
+ (username,),
+ )
+ row = cur.fetchone()
+ return row[0] if row else None
+ finally:
+ conn.close()
+
+
+def _scanJsonbForUid(conn, tableName: str, columnName: str, oldUid: str) -> int:
+ """Count JSONB fields that contain the old UID as a text value anywhere."""
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""SELECT COUNT(*) FROM "{tableName}"
+ WHERE "{columnName}"::text LIKE %s""",
+ (f"%{oldUid}%",),
+ )
+ return cur.fetchone()[0]
+
+
+def _updateJsonbColumn(conn, tableName: str, columnName: str, oldUid: str, newUid: str) -> int:
+ """Replace old UID inside JSONB columns using text replacement."""
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""UPDATE "{tableName}"
+ SET "{columnName}" = REPLACE("{columnName}"::text, %s, %s)::jsonb
+ WHERE "{columnName}"::text LIKE %s""",
+ (oldUid, newUid, f"%{oldUid}%"),
+ )
+ return cur.rowcount
+
+
+def _getJsonbColumns(conn, tableName: str) -> List[str]:
+ """Get all JSONB columns for a table."""
+ with conn.cursor() as cur:
+ cur.execute("""
+ SELECT column_name FROM information_schema.columns
+ WHERE table_schema = 'public'
+ AND table_name = %s
+ AND data_type = 'jsonb'
+ ORDER BY ordinal_position
+ """, (tableName,))
+ return [row[0] for row in cur.fetchall()]
+
+
+def migrate(username: str, oldUid: str, execute: bool = False):
+ newUid = _lookupNewUid(username)
+ if not newUid:
+ logger.error(f"User '{username}' not found in UserInDB. Cannot determine new UID.")
+ sys.exit(1)
+
+ if newUid == oldUid:
+ logger.error(f"Old UID and new UID are identical ({oldUid}). Nothing to migrate.")
+ sys.exit(1)
+
+ logger.info(f"Migration: user '{username}'")
+ logger.info(f" Old UID: {oldUid}")
+ logger.info(f" New UID: {newUid}")
+ logger.info(f" Mode: {'EXECUTE' if execute else 'DRY-RUN'}")
+ logger.info("")
+
+ totalUpdated = 0
+ findings: List[Tuple[str, str, str, int]] = []
+
+ for dbName in ALL_DATABASES:
+ try:
+ conn = _getConnection(dbName)
+ except Exception as e:
+ logger.warning(f" Cannot connect to {dbName}: {e}")
+ continue
+
+ try:
+ conn.autocommit = False
+ tables = _getTablesInDb(conn)
+
+ for tableName in tables:
+ varcharCols = _getVarcharColumns(conn, tableName)
+ for col in varcharCols:
+ count = _countMatches(conn, tableName, col, oldUid)
+ if count > 0:
+ findings.append((dbName, tableName, col, count))
+ if execute:
+ updated = _updateColumn(conn, tableName, col, oldUid, newUid)
+ totalUpdated += updated
+ logger.info(f" [UPDATED] {dbName}.{tableName}.{col}: {updated} rows")
+ else:
+ logger.info(f" [DRY-RUN] {dbName}.{tableName}.{col}: {count} rows would be updated")
+
+ jsonbCols = _getJsonbColumns(conn, tableName)
+ for col in jsonbCols:
+ count = _scanJsonbForUid(conn, tableName, col, oldUid)
+ if count > 0:
+ findings.append((dbName, tableName, f"{col} (JSONB)", count))
+ if execute:
+ _updateJsonbColumn(conn, tableName, col, oldUid, newUid)
+ totalUpdated += count
+ logger.info(f" [UPDATED] {dbName}.{tableName}.{col} (JSONB): {count} rows")
+ else:
+ logger.info(f" [DRY-RUN] {dbName}.{tableName}.{col} (JSONB): {count} rows would be updated")
+
+ if execute:
+ conn.commit()
+ else:
+ conn.rollback()
+ except Exception as e:
+ conn.rollback()
+ logger.error(f" Error processing {dbName}: {e}")
+ finally:
+ conn.close()
+
+ logger.info("")
+ logger.info("=" * 70)
+ logger.info("SUMMARY")
+ logger.info("=" * 70)
+ if not findings:
+ logger.info(" No references to old UID found in any database.")
+ else:
+ logger.info(f" Found {len(findings)} column(s) with references to old UID:")
+ for dbName, tableName, col, count in findings:
+ logger.info(f" {dbName}.{tableName}.{col}: {count} rows")
+ logger.info("")
+ if execute:
+ logger.info(f" Total rows updated: {totalUpdated}")
+ else:
+ logger.info(f" Total rows that would be updated: {sum(c for _, _, _, c in findings)}")
+ logger.info("")
+ logger.info(" To apply changes, re-run with --execute")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Migrate all DB references from old user UID to new UID."
+ )
+ parser.add_argument(
+ "--username",
+ required=True,
+ help="Username to migrate (e.g. 'patrick.helvetia'). Used to look up the new UID.",
+ )
+ parser.add_argument(
+ "--old-uid",
+ required=True,
+ help="The old UUID that is orphaned in the database.",
+ )
+ parser.add_argument(
+ "--execute",
+ action="store_true",
+ default=False,
+ help="Actually perform the migration. Without this flag, only a dry-run is done.",
+ )
+ args = parser.parse_args()
+
+ migrate(username=args.username, oldUid=args.old_uid, execute=args.execute)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py
index 57094760..5fb505d7 100644
--- a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py
+++ b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py
@@ -30,6 +30,7 @@ import psycopg2.errors
from modules.connectors.connectorDbPostgre import (
DatabaseConnector,
DatabaseQueryError,
+ _stripNulBytesFromStr,
)
@@ -164,3 +165,12 @@ class TestGetRecordFailLoud:
assert excinfo.value.table == "DummyTable"
conn.rollback.assert_called_once()
+
+
+class TestStripNulBytesFromStr:
+ def test_removesNul(self):
+ assert _stripNulBytesFromStr("a\x00b") == "ab"
+
+ def test_passthroughNonStr(self):
+ assert _stripNulBytesFromStr(None) is None
+ assert _stripNulBytesFromStr(7) == 7
diff --git a/tests/unit/services/test_buildTree.py b/tests/unit/services/test_buildTree.py
new file mode 100644
index 00000000..5a2bacb4
--- /dev/null
+++ b/tests/unit/services/test_buildTree.py
@@ -0,0 +1,359 @@
+"""Unit tests for the generic UDB tree builder.
+
+Verifies key encoding/decoding and that children for parent keys with
+existing handlers (top-level, conn, mgrp, feat) are produced with the
+correct effective-flag triplet.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import unittest
+from unittest.mock import MagicMock, patch
+
+from modules.serviceCenter.services.serviceKnowledge import _buildTree
+
+
+class TestKeyCoding(unittest.TestCase):
+ def test_encode_decode_roundtrip(self):
+ key = _buildTree._encode("ds", "conn-1", "sharepointFolder", "/sites/x")
+ kind, parts = _buildTree._decode(key)
+ self.assertEqual(kind, "ds")
+ self.assertEqual(parts, ["conn-1", "sharepointFolder", "/sites/x"])
+
+ def test_top_level_kinds(self):
+ self.assertEqual(_buildTree._decode("conn|abc")[0], "conn")
+ self.assertEqual(_buildTree._decode("mgrp|m1")[0], "mgrp")
+ self.assertEqual(_buildTree._decode("feat|m1|trustee|fi-1")[1], ["m1", "trustee", "fi-1"])
+
+
+class TestEffectiveTriplets(unittest.TestCase):
+ def test_ds_triplet_no_record_returns_defaults(self):
+ result = _buildTree._effectiveTripletDs("c", "msft", "/", [])
+ self.assertEqual(result, {
+ "effectiveNeutralize": False,
+ "effectiveScope": "personal",
+ "effectiveRagIndexEnabled": False,
+ })
+
+ def test_ds_triplet_inherits_from_root(self):
+ root = {
+ "id": "r", "connectionId": "c", "sourceType": "msft", "path": "/",
+ "neutralize": True, "scope": "mandate", "ragIndexEnabled": True,
+ }
+ result = _buildTree._effectiveTripletDs("c", "sharepointFolder", "/sites/x", [root])
+ self.assertEqual(result["effectiveNeutralize"], True)
+ self.assertEqual(result["effectiveScope"], "mandate")
+ self.assertEqual(result["effectiveRagIndexEnabled"], True)
+
+ def test_fds_triplet_inherits_from_workspace_wildcard(self):
+ ws = {
+ "id": "ws", "workspaceInstanceId": "ws-inst", "featureInstanceId": "fi1",
+ "tableName": "*", "recordFilter": None, "neutralize": True,
+ "scope": "mandate", "ragIndexEnabled": True,
+ }
+ result = _buildTree._effectiveTripletFds("fi1", "Pos", None, [ws])
+ self.assertEqual(result["effectiveNeutralize"], True)
+ self.assertEqual(result["effectiveScope"], "mandate")
+ self.assertEqual(result["effectiveRagIndexEnabled"], True)
+
+
+class TestRecordLookup(unittest.TestCase):
+ def test_finds_ds_record_by_normalised_path(self):
+ rec = {"id": "x", "connectionId": "c", "sourceType": "msft", "path": "/folder"}
+ self.assertEqual(_buildTree._findDsRecord([rec], "c", "msft", "/folder/").get("id"), "x")
+ self.assertIsNone(_buildTree._findDsRecord([rec], "c", "msft", "/other"))
+
+ def test_finds_fds_record_with_matching_filter(self):
+ rec = {"id": "f", "workspaceInstanceId": "ws", "featureInstanceId": "fi1", "tableName": "Pos", "recordFilter": {"id": "5"}}
+ self.assertEqual(_buildTree._findFdsRecord([rec], "fi1", "Pos", {"id": "5"}).get("id"), "f")
+ self.assertIsNone(_buildTree._findFdsRecord([rec], "fi1", "Pos", {"id": "99"}))
+
+ def test_fds_record_with_none_filter_matches_only_none(self):
+ rec = {"id": "f", "workspaceInstanceId": "ws", "featureInstanceId": "fi1", "tableName": "*", "recordFilter": None}
+ self.assertEqual(_buildTree._findFdsRecord([rec], "fi1", "*", None).get("id"), "f")
+ self.assertIsNone(_buildTree._findFdsRecord([rec], "fi1", "*", {"id": "1"}))
+
+
+class TestGetChildrenForParents(unittest.TestCase):
+ """End-to-end orchestrator test with mocked dependencies."""
+
+ def _runAsync(self, coro):
+ return asyncio.get_event_loop().run_until_complete(coro)
+
+ def test_unknown_parent_key_returns_empty_list(self):
+ with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot:
+ rootIf = MagicMock()
+ rootIf.db.getRecordset.return_value = []
+ mockRoot.return_value = rootIf
+
+ ctx = MagicMock()
+ ctx.user.id = "u1"
+ ctx.mandateId = "m1"
+
+ result = self._runAsync(
+ _buildTree.getChildrenForParents("inst-1", ["bogus|key"], ctx)
+ )
+ self.assertEqual(result["bogus|key"], [])
+
+ def test_top_level_emits_personal_root_first(self):
+ """Top-level emits personalRoot first, then mandate-group nodes inline."""
+ with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot:
+ rootIf = MagicMock()
+ rootIf.db.getRecordset.return_value = []
+ rootIf.getUserMandates.return_value = []
+ mockRoot.return_value = rootIf
+
+ ctx = MagicMock()
+ ctx.user.id = "u1"
+ ctx.mandateId = "m1"
+
+ result = self._runAsync(
+ _buildTree.getChildrenForParents("inst-1", [None], ctx)
+ )
+ children = result["__root__"]
+ self.assertGreaterEqual(len(children), 1)
+ personalRoot = children[0]
+ self.assertEqual(personalRoot["key"], "personalRoot")
+ self.assertEqual(personalRoot["kind"], "synthRoot")
+ self.assertIsNone(personalRoot["parentKey"])
+ self.assertTrue(personalRoot["hasChildren"])
+ self.assertTrue(personalRoot["defaultExpanded"])
+
+
+class TestTopLevelLayout(unittest.TestCase):
+ """Tests for the flat top-level layout (personalRoot + mandate groups)."""
+
+ def _runAsync(self, coro):
+ return asyncio.get_event_loop().run_until_complete(coro)
+
+ def test_personal_root_carries_neutral_default_triplet(self):
+ with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot:
+ rootIf = MagicMock()
+ rootIf.db.getRecordset.return_value = []
+ rootIf.getUserMandates.return_value = []
+ mockRoot.return_value = rootIf
+
+ ctx = MagicMock()
+ ctx.user.id = "u1"
+ ctx.mandateId = "m1"
+
+ result = self._runAsync(
+ _buildTree.getChildrenForParents("inst-1", [None], ctx)
+ )
+ personalRoot = result["__root__"][0]
+ self.assertFalse(personalRoot["effectiveNeutralize"])
+ self.assertEqual(personalRoot["effectiveScope"], "personal")
+ self.assertFalse(personalRoot["effectiveRagIndexEnabled"])
+ self.assertFalse(personalRoot["supportsRag"])
+ self.assertFalse(personalRoot["canBeAdded"])
+ self.assertIsNone(personalRoot["dataSourceId"])
+ self.assertIsNone(personalRoot["modelType"])
+
+ def test_personal_root_emits_active_connection_with_correct_parent(self):
+ with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
+ patch("modules.serviceCenter.getService") as mockGetService:
+ rootIf = MagicMock()
+ rootIf.db.getRecordset.return_value = []
+ mockRoot.return_value = rootIf
+
+ chatService = MagicMock()
+ chatService.getUserConnections.return_value = [{
+ "id": "conn-1",
+ "status": "active",
+ "authority": "msft",
+ "externalEmail": "user@example.com",
+ }]
+ mockGetService.return_value = chatService
+
+ ctx = MagicMock()
+ ctx.user.id = "u1"
+ ctx.mandateId = "m1"
+
+ result = self._runAsync(
+ _buildTree.getChildrenForParents("inst-1", ["personalRoot"], ctx)
+ )
+ children = result["personalRoot"]
+ self.assertEqual(len(children), 1)
+ self.assertEqual(children[0]["key"], "conn|conn-1")
+ self.assertEqual(children[0]["kind"], "connection")
+ self.assertEqual(children[0]["parentKey"], "personalRoot")
+ self.assertEqual(children[0]["label"], "user@example.com")
+ self.assertTrue(children[0]["supportsRag"])
+
+ def test_personal_root_skips_inactive_connection(self):
+ with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
+ patch("modules.serviceCenter.getService") as mockGetService:
+ rootIf = MagicMock()
+ rootIf.db.getRecordset.return_value = []
+ mockRoot.return_value = rootIf
+
+ chatService = MagicMock()
+ chatService.getUserConnections.return_value = [
+ {"id": "c1", "status": "active", "authority": "msft", "externalEmail": "a"},
+ {"id": "c2", "status": "expired", "authority": "google", "externalEmail": "b"},
+ ]
+ mockGetService.return_value = chatService
+
+ ctx = MagicMock()
+ ctx.user.id = "u1"
+ ctx.mandateId = "m1"
+
+ result = self._runAsync(
+ _buildTree.getChildrenForParents("inst-1", ["personalRoot"], ctx)
+ )
+ self.assertEqual(len(result["personalRoot"]), 1)
+ self.assertEqual(result["personalRoot"][0]["connectionId"], "c1")
+
+ def test_mandate_groups_emitted_inline_at_top_level(self):
+ with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
+ patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog:
+ rootIf = MagicMock()
+ rootIf.db.getRecordset.return_value = []
+ userMandate = MagicMock()
+ userMandate.mandateId = "m1"
+ rootIf.getUserMandates.return_value = [userMandate]
+ featureInst = MagicMock()
+ featureInst.id = "fi-1"
+ featureInst.featureCode = "trustee"
+ featureInst.enabled = True
+ rootIf.getFeatureInstancesByMandate.return_value = [featureInst]
+ featureAccess = MagicMock()
+ featureAccess.enabled = True
+ rootIf.getFeatureAccess.return_value = featureAccess
+ mockRoot.return_value = rootIf
+
+ catalog = MagicMock()
+ catalog.getFeaturesWithDataObjects.return_value = ["trustee"]
+ mockCatalog.return_value = catalog
+
+ ctx = MagicMock()
+ ctx.user.id = "u1"
+ ctx.mandateId = None
+
+ result = self._runAsync(
+ _buildTree.getChildrenForParents("inst-1", [None], ctx)
+ )
+ children = result["__root__"]
+ byKey = {c["key"]: c for c in children}
+ self.assertIn("personalRoot", byKey)
+ self.assertIn("mgrp|m1", byKey)
+ mgroup = byKey["mgrp|m1"]
+ self.assertEqual(mgroup["kind"], "mandateGroup")
+ self.assertIsNone(mgroup["parentKey"])
+ self.assertEqual(mgroup["mandateId"], "m1")
+ self.assertTrue(mgroup["defaultExpanded"])
+ self.assertFalse(mgroup["supportsRag"])
+
+ def test_top_level_omits_mandates_without_data_features(self):
+ with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
+ patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog:
+ rootIf = MagicMock()
+ rootIf.db.getRecordset.return_value = []
+ userMandate = MagicMock()
+ userMandate.mandateId = "m1"
+ rootIf.getUserMandates.return_value = [userMandate]
+ rootIf.getFeatureInstancesByMandate.return_value = []
+ mockRoot.return_value = rootIf
+
+ catalog = MagicMock()
+ catalog.getFeaturesWithDataObjects.return_value = ["trustee"]
+ mockCatalog.return_value = catalog
+
+ ctx = MagicMock()
+ ctx.user.id = "u1"
+ ctx.mandateId = None
+
+ result = self._runAsync(
+ _buildTree.getChildrenForParents("inst-1", [None], ctx)
+ )
+ keys = [c["key"] for c in result["__root__"]]
+ self.assertEqual(keys, ["personalRoot"])
+
+ def test_personal_root_listed_first_via_display_order(self):
+ with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
+ patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog:
+ rootIf = MagicMock()
+ rootIf.db.getRecordset.return_value = []
+ userMandate = MagicMock()
+ userMandate.mandateId = "m1"
+ rootIf.getUserMandates.return_value = [userMandate]
+ featureInst = MagicMock()
+ featureInst.id = "fi-1"
+ featureInst.featureCode = "trustee"
+ featureInst.enabled = True
+ rootIf.getFeatureInstancesByMandate.return_value = [featureInst]
+ featureAccess = MagicMock()
+ featureAccess.enabled = True
+ rootIf.getFeatureAccess.return_value = featureAccess
+ mockRoot.return_value = rootIf
+
+ catalog = MagicMock()
+ catalog.getFeaturesWithDataObjects.return_value = ["trustee"]
+ mockCatalog.return_value = catalog
+
+ ctx = MagicMock()
+ ctx.user.id = "u1"
+ ctx.mandateId = None
+
+ result = self._runAsync(
+ _buildTree.getChildrenForParents("inst-1", [None], ctx)
+ )
+ children = result["__root__"]
+ self.assertEqual(children[0]["key"], "personalRoot")
+ self.assertEqual(children[0]["displayOrder"], 0)
+
+
+class TestFeatureTableFields(unittest.TestCase):
+ """Per-column field expansion under a feature data-source table."""
+
+ def test_emits_one_node_per_field(self):
+ nodes = _buildTree._featureTableFields(
+ parentKey="fdstbl|fi-1|TrusteePosition",
+ featureInstanceId="fi-1",
+ tableName="TrusteePosition",
+ fieldNames=["id", "valuta", "company"],
+ allFds=[],
+ )
+ self.assertEqual(len(nodes), 3)
+ self.assertEqual(nodes[0]["kind"], "fdsField")
+ self.assertEqual(nodes[0]["fieldName"], "id")
+ self.assertEqual(nodes[0]["parentKey"], "fdstbl|fi-1|TrusteePosition")
+ self.assertEqual(nodes[0]["key"], "fdsfld|fi-1|TrusteePosition|id")
+ self.assertFalse(nodes[0]["hasChildren"])
+ self.assertFalse(nodes[0]["supportsRag"])
+
+ def test_field_neutralize_inherits_from_table_blanket(self):
+ rec = {"id": "f", "workspaceInstanceId": "ws-1", "featureInstanceId": "fi-1",
+ "tableName": "TrusteePosition", "recordFilter": None,
+ "neutralize": True, "neutralizeFields": None,
+ "scope": None, "ragIndexEnabled": False}
+ nodes = _buildTree._featureTableFields(
+ parentKey="fdstbl|fi-1|TrusteePosition",
+ featureInstanceId="fi-1",
+ tableName="TrusteePosition",
+ fieldNames=["email", "company"],
+ allFds=[rec],
+ )
+ self.assertTrue(nodes[0]["effectiveNeutralize"])
+ self.assertTrue(nodes[1]["effectiveNeutralize"])
+
+ def test_field_neutralize_explicit_via_neutralize_fields(self):
+ rec = {"id": "f", "workspaceInstanceId": "ws-1", "featureInstanceId": "fi-1",
+ "tableName": "TrusteePosition", "recordFilter": None,
+ "neutralize": False, "neutralizeFields": ["email"],
+ "scope": None, "ragIndexEnabled": False}
+ nodes = _buildTree._featureTableFields(
+ parentKey="fdstbl|fi-1|TrusteePosition",
+ featureInstanceId="fi-1",
+ tableName="TrusteePosition",
+ fieldNames=["email", "company"],
+ allFds=[rec],
+ )
+ byField = {n["fieldName"]: n for n in nodes}
+ self.assertTrue(byField["email"]["effectiveNeutralize"])
+ self.assertFalse(byField["company"]["effectiveNeutralize"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unit/services/test_inheritFlags.py b/tests/unit/services/test_inheritFlags.py
index b177e767..98e6fb41 100644
--- a/tests/unit/services/test_inheritFlags.py
+++ b/tests/unit/services/test_inheritFlags.py
@@ -1,12 +1,12 @@
"""Unit tests for `_inheritFlags` cascade-inherit helpers.
Verifies:
-- getEffectiveFlag walks ancestors via path-prefix matching
-- root default is False (or 'personal' for scope) when nothing explicit in chain
-- only same-connectionId AND same-sourceType ancestors are considered
-- cascadeResetDescendants only touches descendants with explicit values for THAT flag
-- '/' is treated as ancestor of every non-root path
-- '/foo' is NOT ancestor of '/foobar' (must require '/' separator)
+- getEffectiveFlag mode='walk': walks ancestors via path-prefix matching
+- getEffectiveFlag mode='aggregate': returns 'mixed' when subtree diverges
+- cascadeResetDescendants: bottom-up reset returning List[str]
+- cascadeResetDescendantsFds: same for FeatureDataSource
+- collectAncestorChain / collectAncestorChainFds: ancestor discovery
+- buildEffectiveByConnection / buildEffectiveByWorkspaceFds: batch compute
"""
from __future__ import annotations
@@ -33,7 +33,26 @@ def _ds(idVal: str, path: str, **flags) -> dict:
return base
-class TestEffectiveFlag(unittest.TestCase):
+def _fds(idVal: str, *, tableName: str, recordFilter=None, featureInstanceId="fi-1", **flags) -> dict:
+ """Build a FeatureDataSource dict fixture."""
+ base = {
+ "id": idVal,
+ "workspaceInstanceId": "ws-1",
+ "featureInstanceId": featureInstanceId,
+ "tableName": tableName,
+ "recordFilter": recordFilter,
+ "neutralize": None,
+ "scope": None,
+ }
+ base.update(flags)
+ return base
+
+
+# ===========================================================================
+# DataSource: getEffectiveFlag mode='walk'
+# ===========================================================================
+
+class TestEffectiveFlagWalk(unittest.TestCase):
def test_explicit_own_value_wins(self):
root = _ds("r", "/", neutralize=False)
leaf = _ds("l", "/folder/sub", neutralize=True)
@@ -65,7 +84,6 @@ class TestEffectiveFlag(unittest.TestCase):
self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [otherType, leaf]))
def test_path_separator_required(self):
- """`/foo` must NOT be ancestor of `/foobar` (no shared `/` boundary)."""
notAncestor = _ds("a", "/foo", neutralize=True)
leaf = _ds("l", "/foobar")
self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [notAncestor, leaf]))
@@ -90,32 +108,101 @@ class TestEffectiveFlag(unittest.TestCase):
_inheritFlags.getEffectiveFlag(leaf, "unknownFlag", [leaf])
def test_explicit_false_overrides_inherited_true(self):
- """Explicit False on a child must NOT cascade up to True from an ancestor."""
root = _ds("r", "/", neutralize=True)
leaf = _ds("l", "/folder", neutralize=False)
self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf]))
def test_connection_root_inherits_cross_sourcetype(self):
- """Connection-root (sourceType=authority, path='/') is ancestor of all DS in that connection."""
connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
spService = _ds("sp", "/", sourceType="sharepointFolder")
olService = _ds("ol", "/", sourceType="outlookFolder")
- self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", [connRoot, spService, olService]))
- self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", [connRoot, spService, olService]))
+ allDs = [connRoot, spService, olService]
+ self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", allDs))
+ self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", allDs))
def test_same_sourcetype_ancestor_wins_over_connection_root(self):
- """A same-sourceType service-root ancestor beats the connection-root."""
connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
spRoot = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False)
spLeaf = _ds("spl", "/sites/x", sourceType="sharepointFolder")
self.assertFalse(_inheritFlags.getEffectiveFlag(spLeaf, "neutralize", [connRoot, spRoot, spLeaf]))
def test_connection_root_does_not_self_inherit(self):
- """Connection-root has no ancestor — does not infinite-loop on itself."""
connRoot = _ds("conn", "/", sourceType="msft")
self.assertFalse(_inheritFlags.getEffectiveFlag(connRoot, "neutralize", [connRoot]))
+# ===========================================================================
+# DataSource: getEffectiveFlag mode='aggregate'
+# ===========================================================================
+
+class TestEffectiveFlagAggregate(unittest.TestCase):
+ def test_leaf_without_descendants_returns_concrete(self):
+ leaf = _ds("l", "/folder", neutralize=True)
+ self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [leaf], mode="aggregate"))
+
+ def test_all_descendants_same_returns_concrete(self):
+ root = _ds("r", "/", neutralize=True)
+ child1 = _ds("c1", "/a", neutralize=True)
+ child2 = _ds("c2", "/b") # inherits True from root
+ allDs = [root, child1, child2]
+ self.assertTrue(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"))
+
+ def test_divergent_descendants_returns_mixed(self):
+ root = _ds("r", "/", neutralize=True)
+ child1 = _ds("c1", "/a", neutralize=False)
+ child2 = _ds("c2", "/b") # inherits True from root
+ allDs = [root, child1, child2]
+ self.assertEqual(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"), "mixed")
+
+ def test_mixed_scope(self):
+ root = _ds("r", "/", scope="personal")
+ child1 = _ds("c1", "/a", scope="team")
+ child2 = _ds("c2", "/b") # inherits personal from root
+ allDs = [root, child1, child2]
+ self.assertEqual(_inheritFlags.getEffectiveFlag(root, "scope", allDs, mode="aggregate"), "mixed")
+
+ def test_all_scope_same_explicit_returns_concrete(self):
+ root = _ds("r", "/", scope="team")
+ child1 = _ds("c1", "/a", scope="team")
+ child2 = _ds("c2", "/b") # inherits team
+ allDs = [root, child1, child2]
+ self.assertEqual(_inheritFlags.getEffectiveFlag(root, "scope", allDs, mode="aggregate"), "team")
+
+ def test_connection_root_aggregate_cross_sourcetype(self):
+ connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
+ spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False)
+ olInherit = _ds("ol", "/", sourceType="outlookFolder") # inherits True
+ allDs = [connRoot, spExplicit, olInherit]
+ self.assertEqual(
+ _inheritFlags.getEffectiveFlag(connRoot, "neutralize", allDs, mode="aggregate"),
+ "mixed",
+ )
+
+ def test_mid_level_aggregate_only_considers_own_subtree(self):
+ root = _ds("r", "/", neutralize=True)
+ mid = _ds("m", "/folder", neutralize=True)
+ midChild = _ds("mc", "/folder/sub", neutralize=True)
+ sibling = _ds("s", "/other", neutralize=False) # not under mid
+ allDs = [root, mid, midChild, sibling]
+ # mid's subtree is just midChild(True) + mid(True) = uniform
+ self.assertTrue(_inheritFlags.getEffectiveFlag(mid, "neutralize", allDs, mode="aggregate"))
+ # root's subtree includes sibling(False) = mixed
+ self.assertEqual(
+ _inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"),
+ "mixed",
+ )
+
+ def test_walk_mode_never_returns_mixed(self):
+ root = _ds("r", "/", neutralize=True)
+ child = _ds("c", "/a", neutralize=False)
+ allDs = [root, child]
+ self.assertTrue(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="walk"))
+
+
+# ===========================================================================
+# DataSource: cascadeResetDescendants (bottom-up, List[str])
+# ===========================================================================
+
class TestCascadeReset(unittest.TestCase):
def _makeRootIf(self, dataSources: List[dict]):
rootIf = MagicMock()
@@ -127,54 +214,76 @@ class TestCascadeReset(unittest.TestCase):
rootIf.db.recordModify = MagicMock(side_effect=_modify)
return rootIf, modified
+ def test_returns_list_of_ids(self):
+ parent = _ds("p", "/sites", neutralize=True)
+ child = _ds("c1", "/sites/folder1", neutralize=False)
+ rootIf, _ = self._makeRootIf([parent, child])
+ result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
+ self.assertIsInstance(result, list)
+ self.assertEqual(result, ["c1"])
+
def test_resets_only_explicit_descendants(self):
parent = _ds("p", "/sites", neutralize=True)
explicitChild = _ds("c1", "/sites/folder1", neutralize=False)
- inheritChild = _ds("c2", "/sites/folder2") # inherit -> not touched
- sibling = _ds("s", "/other", neutralize=True) # NOT a descendant
+ inheritChild = _ds("c2", "/sites/folder2")
+ sibling = _ds("s", "/other", neutralize=True)
rootIf, modified = self._makeRootIf([parent, explicitChild, inheritChild, sibling])
- affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
+ result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
- self.assertEqual(affected, 1)
+ self.assertEqual(result, ["c1"])
self.assertEqual(modified, [("c1", {"neutralize": None})])
- def test_does_not_touch_other_flags(self):
- parent = _ds("p", "/sites", neutralize=True)
- child = _ds("c", "/sites/sub", neutralize=False, ragIndexEnabled=True)
+ def test_bottom_up_order(self):
+ """Deepest items are reset first."""
+ parent = _ds("p", "/", neutralize=True)
+ level1 = _ds("l1", "/a", neutralize=False)
+ level2 = _ds("l2", "/a/b", neutralize=False)
+ level3 = _ds("l3", "/a/b/c", neutralize=False)
+ rootIf, modified = self._makeRootIf([parent, level1, level2, level3])
+
+ result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
+
+ self.assertEqual(result, ["l3", "l2", "l1"])
+
+ def test_deep_cascade_through_null_items(self):
+ """null items are skipped (no DB write) but cascade continues deeper."""
+ parent = _ds("p", "/", neutralize=True)
+ nullChild = _ds("n", "/a") # null — no write, but not a barrier
+ deepExplicit = _ds("d", "/a/b", neutralize=False)
+ rootIf, modified = self._makeRootIf([parent, nullChild, deepExplicit])
+
+ result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
+
+ self.assertEqual(result, ["d"])
+ self.assertEqual(modified, [("d", {"neutralize": None})])
+
+ def test_does_not_modify_parent(self):
+ parent = _ds("p", "/", neutralize=True)
+ child = _ds("c", "/a", neutralize=False)
rootIf, modified = self._makeRootIf([parent, child])
-
_inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
-
- self.assertEqual(modified, [("c", {"neutralize": None})])
- # ragIndexEnabled and scope on the child must remain untouched.
-
- def test_does_not_cross_sourcetype(self):
- """Non-connection-root parents stay within their sourceType for cascade."""
- parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder")
- otherTypeDescendant = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder")
- rootIf, modified = self._makeRootIf([parent, otherTypeDescendant])
-
- affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
-
- self.assertEqual(affected, 0)
- self.assertEqual(modified, [])
+ self.assertNotIn("p", [m[0] for m in modified])
def test_connection_root_cascades_cross_sourcetype(self):
- """Toggle on connection-root cascades into every explicit DS of that connection."""
connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False)
olInherit = _ds("ol", "/", sourceType="outlookFolder")
- spLeafExplicit = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True)
- rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeafExplicit])
+ spLeaf = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True)
+ rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeaf])
- affected = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize")
+ result = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize")
- # spExplicit and spLeafExplicit had explicit values → reset. olInherit untouched.
- self.assertEqual(affected, 2)
- self.assertEqual({m[0] for m in modified}, {"sp", "sp-leaf"})
- for _, fields in modified:
- self.assertEqual(fields, {"neutralize": None})
+ self.assertEqual(set(result), {"sp", "sp-leaf"})
+ # sp-leaf is deeper, should come first
+ self.assertEqual(result[0], "sp-leaf")
+
+ def test_does_not_cross_sourcetype_for_non_authority(self):
+ parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder")
+ otherType = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder")
+ rootIf, modified = self._makeRootIf([parent, otherType])
+ result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
+ self.assertEqual(result, [])
def test_unknown_flag_raises(self):
parent = _ds("p", "/", neutralize=True)
@@ -183,57 +292,59 @@ class TestCascadeReset(unittest.TestCase):
_inheritFlags.cascadeResetDescendants(rootIf, parent, "unknownFlag")
-def _fds(idVal: str, *, tableName: str, recordFilter=None, **flags) -> dict:
- """Build a FeatureDataSource dict fixture."""
- base = {
- "id": idVal,
- "workspaceInstanceId": "ws-1",
- "tableName": tableName,
- "recordFilter": recordFilter,
- "neutralize": None,
- "scope": None,
- }
- base.update(flags)
- return base
+# ===========================================================================
+# DataSource: collectAncestorChain
+# ===========================================================================
+
+class TestCollectAncestorChain(unittest.TestCase):
+ def test_returns_nearest_first(self):
+ root = _ds("r", "/", neutralize=True)
+ mid = _ds("m", "/a")
+ leaf = _ds("l", "/a/b")
+ chain = _inheritFlags.collectAncestorChain(leaf, [root, mid, leaf])
+ self.assertEqual([_inheritFlags._getRecordValue(c, "id") for c in chain], ["m", "r"])
+
+ def test_connection_root_is_last(self):
+ connRoot = _ds("conn", "/", sourceType="msft")
+ spRoot = _ds("sp", "/", sourceType="sharepointFolder")
+ spLeaf = _ds("spl", "/sub", sourceType="sharepointFolder")
+ chain = _inheritFlags.collectAncestorChain(spLeaf, [connRoot, spRoot, spLeaf])
+ ids = [_inheritFlags._getRecordValue(c, "id") for c in chain]
+ self.assertEqual(ids, ["sp", "conn"])
+
+ def test_root_has_no_ancestors(self):
+ root = _ds("r", "/")
+ chain = _inheritFlags.collectAncestorChain(root, [root])
+ self.assertEqual(chain, [])
-class TestFdsClassifyAndAncestry(unittest.TestCase):
- def test_classify_workspace_wildcard(self):
- self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="*")), "workspace")
+# ===========================================================================
+# DataSource: buildEffectiveByConnection
+# ===========================================================================
- def test_classify_table_wildcard(self):
- self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="Pos")), "table")
+class TestBuildEffectiveByConnection(unittest.TestCase):
+ def test_walk_mode(self):
+ root = _ds("r", "/", neutralize=True)
+ child = _ds("c", "/a", neutralize=False)
+ leaf = _ds("l", "/a/b") # inherits False from child
+ result = _inheritFlags.buildEffectiveByConnection([root, child, leaf], "neutralize", mode="walk")
+ self.assertEqual(result, {"r": True, "c": False, "l": False})
- def test_classify_record_specific(self):
- rec = _fds("a", tableName="Pos", recordFilter={"id": "r-1"})
- self.assertEqual(_inheritFlags._fdsClassify(rec), "record")
-
- def test_workspace_is_ancestor_of_table_and_record(self):
- ws = _fds("ws", tableName="*")
- tbl = _fds("t", tableName="Pos")
- rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
- self.assertTrue(_inheritFlags._fdsIsAncestor(ws, tbl))
- self.assertTrue(_inheritFlags._fdsIsAncestor(ws, rec))
-
- def test_table_is_ancestor_of_record_same_table_only(self):
- tbl = _fds("t", tableName="Pos")
- recSame = _fds("r1", tableName="Pos", recordFilter={"id": "1"})
- recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"})
- self.assertTrue(_inheritFlags._fdsIsAncestor(tbl, recSame))
- self.assertFalse(_inheritFlags._fdsIsAncestor(tbl, recOther))
-
- def test_record_has_no_descendants(self):
- rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
- tbl = _fds("t", tableName="Pos")
- self.assertFalse(_inheritFlags._fdsIsAncestor(rec, tbl))
-
- def test_no_cross_workspace_ancestry(self):
- ws = _fds("ws", tableName="*", workspaceInstanceId="ws-A")
- rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, workspaceInstanceId="ws-B")
- self.assertFalse(_inheritFlags._fdsIsAncestor(ws, rec))
+ def test_aggregate_mode(self):
+ root = _ds("r", "/", neutralize=True)
+ child = _ds("c", "/a", neutralize=False)
+ leaf = _ds("l", "/a/b") # inherits False from child
+ result = _inheritFlags.buildEffectiveByConnection([root, child, leaf], "neutralize", mode="aggregate")
+ self.assertEqual(result["r"], "mixed")
+ self.assertEqual(result["c"], False)
+ self.assertEqual(result["l"], False)
-class TestFdsEffectiveFlag(unittest.TestCase):
+# ===========================================================================
+# FeatureDataSource: getEffectiveFlagFds
+# ===========================================================================
+
+class TestFdsEffectiveFlagWalk(unittest.TestCase):
def test_own_explicit_wins(self):
ws = _fds("ws", tableName="*", neutralize=False)
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
@@ -262,9 +373,50 @@ class TestFdsEffectiveFlag(unittest.TestCase):
def test_unknown_flag_raises(self):
rec = _fds("r", tableName="*")
with self.assertRaises(ValueError):
- _inheritFlags.getEffectiveFlagFds(rec, "ragIndexEnabled", [rec])
+ _inheritFlags.getEffectiveFlagFds(rec, "doesNotExist", [rec])
+class TestFdsEffectiveFlagAggregate(unittest.TestCase):
+ def test_leaf_without_descendants(self):
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
+ self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [rec], mode="aggregate"))
+
+ def test_all_descendants_same(self):
+ ws = _fds("ws", tableName="*", neutralize=True)
+ tbl = _fds("t", tableName="Pos") # inherits True
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits True
+ allFds = [ws, tbl, rec]
+ self.assertTrue(_inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"))
+
+ def test_divergent_descendants_returns_mixed(self):
+ ws = _fds("ws", tableName="*", neutralize=True)
+ tbl = _fds("t", tableName="Pos", neutralize=False)
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits False from tbl
+ allFds = [ws, tbl, rec]
+ self.assertEqual(
+ _inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"),
+ "mixed",
+ )
+
+ def test_table_aggregate_own_subtree_only(self):
+ ws = _fds("ws", tableName="*", neutralize=True)
+ tblA = _fds("tA", tableName="A", neutralize=True)
+ recA = _fds("rA", tableName="A", recordFilter={"id": "1"}, neutralize=True)
+ tblB = _fds("tB", tableName="B", neutralize=False)
+ allFds = [ws, tblA, recA, tblB]
+ # tblA subtree: all True
+ self.assertTrue(_inheritFlags.getEffectiveFlagFds(tblA, "neutralize", allFds, mode="aggregate"))
+ # ws subtree: mixed (tblB is False)
+ self.assertEqual(
+ _inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"),
+ "mixed",
+ )
+
+
+# ===========================================================================
+# FeatureDataSource: cascadeResetDescendantsFds (bottom-up, List[str])
+# ===========================================================================
+
class TestFdsCascadeReset(unittest.TestCase):
def _makeRootIf(self, fdses):
rootIf = MagicMock()
@@ -276,6 +428,14 @@ class TestFdsCascadeReset(unittest.TestCase):
rootIf.db.recordModify = MagicMock(side_effect=_modify)
return rootIf, modified
+ def test_returns_list_of_ids(self):
+ ws = _fds("ws", tableName="*", neutralize=True)
+ tbl = _fds("t", tableName="Pos", neutralize=False)
+ rootIf, _ = self._makeRootIf([ws, tbl])
+ result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize")
+ self.assertIsInstance(result, list)
+ self.assertEqual(result, ["t"])
+
def test_workspace_cascades_to_all_explicit_descendants(self):
ws = _fds("ws", tableName="*", neutralize=True)
tblExplicit = _fds("t", tableName="Pos", neutralize=False)
@@ -283,10 +443,11 @@ class TestFdsCascadeReset(unittest.TestCase):
recExplicit = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
rootIf, modified = self._makeRootIf([ws, tblExplicit, tblInherit, recExplicit])
- affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize")
+ result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize")
- self.assertEqual(affected, 2)
- self.assertEqual({m[0] for m in modified}, {"t", "r"})
+ self.assertEqual(set(result), {"t", "r"})
+ # record is deeper (depth 2) than table (depth 1), should come first
+ self.assertEqual(result[0], "r")
def test_table_cascades_only_to_same_table_records(self):
tbl = _fds("t", tableName="Pos", neutralize=True)
@@ -294,25 +455,189 @@ class TestFdsCascadeReset(unittest.TestCase):
recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}, neutralize=False)
rootIf, modified = self._makeRootIf([tbl, recSame, recOther])
- affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize")
+ result = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize")
- self.assertEqual(affected, 1)
+ self.assertEqual(result, ["r1"])
self.assertEqual(modified, [("r1", {"neutralize": None})])
def test_record_has_no_cascade(self):
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
rootIf, modified = self._makeRootIf([rec])
- affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize")
- self.assertEqual(affected, 0)
- self.assertEqual(modified, [])
+ result = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize")
+ self.assertEqual(result, [])
def test_unknown_flag_raises(self):
ws = _fds("ws", tableName="*", neutralize=True)
rootIf, _ = self._makeRootIf([ws])
with self.assertRaises(ValueError):
- _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled")
+ _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "doesNotExist")
+# ===========================================================================
+# FeatureDataSource: collectAncestorChainFds
+# ===========================================================================
+
+class TestCollectAncestorChainFds(unittest.TestCase):
+ def test_record_has_table_then_workspace(self):
+ ws = _fds("ws", tableName="*")
+ tbl = _fds("t", tableName="Pos")
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
+ chain = _inheritFlags.collectAncestorChainFds(rec, [ws, tbl, rec])
+ ids = [c["id"] for c in chain]
+ self.assertEqual(ids, ["t", "ws"])
+
+ def test_table_has_only_workspace(self):
+ ws = _fds("ws", tableName="*")
+ tbl = _fds("t", tableName="Pos")
+ chain = _inheritFlags.collectAncestorChainFds(tbl, [ws, tbl])
+ self.assertEqual([c["id"] for c in chain], ["ws"])
+
+ def test_workspace_has_no_ancestors(self):
+ ws = _fds("ws", tableName="*")
+ chain = _inheritFlags.collectAncestorChainFds(ws, [ws])
+ self.assertEqual(chain, [])
+
+
+# ===========================================================================
+# FeatureDataSource: buildEffectiveByWorkspaceFds
+# ===========================================================================
+
+class TestBuildEffectiveByWorkspaceFds(unittest.TestCase):
+ def test_walk_mode(self):
+ ws = _fds("ws", tableName="*", neutralize=True)
+ tbl = _fds("t", tableName="Pos", neutralize=False)
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits False from tbl
+ result = _inheritFlags.buildEffectiveByWorkspaceFds([ws, tbl, rec], "neutralize", mode="walk")
+ self.assertEqual(result, {"ws": True, "t": False, "r": False})
+
+ def test_aggregate_mode(self):
+ ws = _fds("ws", tableName="*", neutralize=True)
+ tbl = _fds("t", tableName="Pos", neutralize=False)
+ rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
+ result = _inheritFlags.buildEffectiveByWorkspaceFds([ws, tbl, rec], "neutralize", mode="aggregate")
+ self.assertEqual(result["ws"], "mixed")
+ self.assertEqual(result["t"], False)
+ self.assertEqual(result["r"], False)
+
+
+# ===========================================================================
+# resolveEffectiveForPath (with and without own record)
+# ===========================================================================
+
+class TestResolveEffectiveForPath(unittest.TestCase):
+ def test_with_exact_record(self):
+ root = _ds("r", "/", neutralize=True, scope="mandate", ragIndexEnabled=False)
+ leaf = _ds("l", "/folder/sub", neutralize=False)
+ allDs = [root, leaf]
+ result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/folder/sub", allDs)
+ self.assertEqual(result["effectiveNeutralize"], False)
+ self.assertEqual(result["effectiveScope"], "mandate")
+ self.assertEqual(result["effectiveRagIndexEnabled"], False)
+
+ def test_without_record_inherits_from_ancestor(self):
+ root = _ds("r", "/", neutralize=True, scope="mandate", ragIndexEnabled=True)
+ allDs = [root]
+ result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/deep/path/file.txt", allDs)
+ self.assertEqual(result["effectiveNeutralize"], True)
+ self.assertEqual(result["effectiveScope"], "mandate")
+ self.assertEqual(result["effectiveRagIndexEnabled"], True)
+
+ def test_without_record_inherits_from_closest_ancestor(self):
+ root = _ds("r", "/", neutralize=True, ragIndexEnabled=True)
+ mid = _ds("m", "/folder", neutralize=False, ragIndexEnabled=False)
+ allDs = [root, mid]
+ result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/folder/sub/file.txt", allDs)
+ self.assertEqual(result["effectiveNeutralize"], False)
+ self.assertEqual(result["effectiveRagIndexEnabled"], False)
+
+ def test_without_record_no_ancestors_returns_defaults(self):
+ allDs: list = []
+ result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/path", allDs)
+ self.assertEqual(result["effectiveNeutralize"], False)
+ self.assertEqual(result["effectiveScope"], "personal")
+ self.assertEqual(result["effectiveRagIndexEnabled"], False)
+
+ def test_connection_root_covers_service_subtree(self):
+ connRoot = _ds("cr", "/", neutralize=True, sourceType="msft")
+ allDs = [connRoot]
+ result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/sites/intranet", allDs)
+ self.assertEqual(result["effectiveNeutralize"], True)
+
+ def test_exact_record_with_aggregate_mixed(self):
+ root = _ds("r", "/", neutralize=True)
+ leaf = _ds("l", "/sub", neutralize=False)
+ allDs = [root, leaf]
+ result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/", allDs, mode="aggregate")
+ self.assertEqual(result["effectiveNeutralize"], "mixed")
+
+
+class TestResolveEffectiveForFds(unittest.TestCase):
+ def test_with_exact_record(self):
+ ws = _fds("ws", tableName="*", neutralize=True, scope="mandate")
+ tbl = _fds("t", tableName="Pos", neutralize=False, scope="personal")
+ allFds = [ws, tbl]
+ result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds)
+ self.assertEqual(result["effectiveNeutralize"], False)
+ self.assertEqual(result["effectiveScope"], "personal")
+ self.assertEqual(result["effectiveRagIndexEnabled"], False)
+
+ def test_without_record_inherits_from_workspace_wildcard(self):
+ ws = _fds("ws", tableName="*", neutralize=True, scope="mandate", ragIndexEnabled=True)
+ allFds = [ws]
+ result = _inheritFlags.resolveEffectiveForFds("fi-1", "Unknown", None, allFds)
+ self.assertEqual(result["effectiveNeutralize"], True)
+ self.assertEqual(result["effectiveScope"], "mandate")
+ self.assertEqual(result["effectiveRagIndexEnabled"], True)
+
+ def test_without_record_no_ancestors_returns_defaults(self):
+ allFds: list = []
+ result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds)
+ self.assertEqual(result["effectiveNeutralize"], False)
+ self.assertEqual(result["effectiveScope"], "personal")
+ self.assertEqual(result["effectiveRagIndexEnabled"], False)
+
+ def test_rag_inherits_when_table_overrides_neutralize_only(self):
+ """Tables that override only neutralize must still inherit RAG from parent."""
+ ws = _fds("ws", tableName="*", ragIndexEnabled=True)
+ tbl = _fds("t", tableName="Pos", neutralize=False)
+ allFds = [ws, tbl]
+ result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds)
+ self.assertEqual(result["effectiveRagIndexEnabled"], True)
+
+ def test_rag_aggregate_mixed_when_descendants_diverge(self):
+ ws = _fds("ws", tableName="*", ragIndexEnabled=True)
+ tbl = _fds("t", tableName="Pos", ragIndexEnabled=False)
+ allFds = [ws, tbl]
+ result = _inheritFlags.resolveEffectiveForFds("fi-1", "*", None, allFds, mode="aggregate")
+ self.assertEqual(result["effectiveRagIndexEnabled"], "mixed")
+
+ def test_inheritable_fds_flags_includes_rag(self):
+ self.assertIn("ragIndexEnabled", _inheritFlags._INHERITABLE_FDS_FLAGS)
+ self.assertIn("neutralize", _inheritFlags._INHERITABLE_FDS_FLAGS)
+ self.assertIn("scope", _inheritFlags._INHERITABLE_FDS_FLAGS)
+
+
+# ===========================================================================
+# FDS cascade resets RAG (in addition to neutralize and scope)
+# ===========================================================================
+
+class TestCascadeResetFdsRag(unittest.TestCase):
+ def test_cascade_resets_rag_on_descendants(self):
+ ws = _fds("ws", tableName="*")
+ tbl = _fds("t", tableName="Pos", ragIndexEnabled=False)
+ allFds = [ws, tbl]
+ rootIf = MagicMock()
+ rootIf.db.getRecordset.return_value = allFds
+ rootIf.db.recordModify = MagicMock()
+ result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled")
+ self.assertIn("t", result)
+ rootIf.db.recordModify.assert_called()
+
+
+# ===========================================================================
+# Path normalization
+# ===========================================================================
+
class TestPathNormalization(unittest.TestCase):
def test_empty_path_normalises_to_root(self):
self.assertEqual(_inheritFlags._normalisePath(""), "/")
diff --git a/tests/unit/teamsbot/test_directorPrompts.py b/tests/unit/teamsbot/test_directorPrompts.py
index f136438a..b8bdaafc 100644
--- a/tests/unit/teamsbot/test_directorPrompts.py
+++ b/tests/unit/teamsbot/test_directorPrompts.py
@@ -42,7 +42,7 @@ from modules.features.teamsbot.datamodelTeamsbot import (
from modules.features.teamsbot.service import (
TeamsbotService,
_activeServices,
- _sessionEvents,
+ sessionEvents,
getActiveService,
)
@@ -152,10 +152,10 @@ def _buildService() -> TeamsbotService:
def _resetGlobals():
"""Avoid cross-test bleed in module-level globals."""
_activeServices.clear()
- _sessionEvents.clear()
+ sessionEvents.clear()
yield
_activeServices.clear()
- _sessionEvents.clear()
+ sessionEvents.clear()
# ============================================================================
@@ -251,7 +251,7 @@ class TestBuildPersistentDirectorContext:
]
rendered = svc._buildPersistentDirectorContext()
assert "OPERATOR_DIRECTIVES" in rendered
- assert "- Antworte immer in Englisch." in rendered
+ assert "Antworte immer in Englisch." in rendered
assert "private" in rendered
def test_skipsBlankText(self):
@@ -261,7 +261,7 @@ class TestBuildPersistentDirectorContext:
{"id": "p2", "text": "Sei hoeflich."},
]
rendered = svc._buildPersistentDirectorContext()
- assert "- Sei hoeflich." in rendered
+ assert "Sei hoeflich." in rendered
assert "p1" not in rendered # the blank one is filtered out
def test_allBlankPromptsResultInEmpty(self):
From 9773c00bca3b48216c86adb6e733ae1060f37851 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Tue, 19 May 2026 17:38:18 +0200
Subject: [PATCH 4/6] trustee budget fix
---
modules/features/trustee/mainTrustee.py | 10 ++-
.../actions/refreshAccountingData.py | 83 +++++++++++++++----
2 files changed, 74 insertions(+), 19 deletions(-)
diff --git a/modules/features/trustee/mainTrustee.py b/modules/features/trustee/mainTrustee.py
index b3f7cdcf..41903211 100644
--- a/modules/features/trustee/mainTrustee.py
+++ b/modules/features/trustee/mainTrustee.py
@@ -484,8 +484,14 @@ TEMPLATE_WORKFLOWS = [
"3. Kurzer Management-Summary-Absatz (3-5 Saetze) UNTER dem Chart "
"mit den 3 groessten Abweichungen (>10%) und einer fachlichen "
"Einschaetzung.\n\n"
- "Verwende die uebergebene Budget-Datei als Soll-Quelle und die im "
- "Kontext bereitgestellten Buchhaltungsdaten als Ist-Quelle.\n"
+ "DATENQUELLEN:\n"
+ "- SOLL (Budget): Aus der uebergebenen Budget-Datei (Excel).\n"
+ "- IST (Buchhaltung): Verwende AUSSCHLIESSLICH das Feld "
+ "\"closingBalance\" aus \"accountSummary\" im Kontext-JSON. "
+ "Dort steht pro Konto GENAU EIN Ist-Wert (Jahresabschluss-Saldo). "
+ "Fuer Quartals-Budgets stehen zusaetzlich Q1/Q2/Q3/Q4-Felder bereit. "
+ "SUMMIERE NIEMALS mehrere Zeilen oder Journal-Eintraege auf -- der "
+ "closingBalance in accountSummary ist bereits der korrekte Ist-Wert.\n\n"
"WICHTIG: Erstelle KEINEN separaten Chart pro Konto. Nur EIN "
"Uebersichts-Chart ueber alle Konten ist gewuenscht.\n\n"
"Hinweis: Das documentTheme ist 'finance'. Wenn du ein Dokument erstellst, "
diff --git a/modules/workflows/methods/methodTrustee/actions/refreshAccountingData.py b/modules/workflows/methods/methodTrustee/actions/refreshAccountingData.py
index 6ff5641c..0d6e737c 100644
--- a/modules/workflows/methods/methodTrustee/actions/refreshAccountingData.py
+++ b/modules/workflows/methods/methodTrustee/actions/refreshAccountingData.py
@@ -38,6 +38,52 @@ def _tsToIso(ts) -> Optional[str]:
_SYNC_THRESHOLD_SECONDS = 3600
+def _buildAccountSummary(accountMap: Dict[str, dict], balances: list, year: int) -> list:
+ """Aggregate balance records into one row per account for *year*.
+
+ For each account the annual balance record (``periodMonth == 0``) of
+ *year* is preferred. If that row is missing, we also check the
+ previous year's annual record so that YTD carry-forwards are visible.
+ Additionally, quarterly closing balances (Q1-Q4) are derived from the
+ monthly records so the AI can compare against quarterly budgets.
+ """
+ bestClosing: Dict[str, float] = {}
+ quarterClosing: Dict[str, Dict[str, float]] = {}
+
+ for b in balances:
+ acct = b.get("accountNumber", "")
+ bYear = b.get("periodYear", 0)
+ bMonth = b.get("periodMonth", 0)
+ closing = b.get("closingBalance", 0) or 0
+
+ if bYear == year and bMonth == 0:
+ bestClosing[acct] = closing
+
+ if bYear == year and bMonth in (3, 6, 9, 12):
+ qLabel = f"Q{bMonth // 3}"
+ quarterClosing.setdefault(acct, {})[qLabel] = closing
+
+ if acct not in bestClosing and bYear == year - 1 and bMonth == 0:
+ bestClosing[acct] = closing
+
+ summary = []
+ for nr in sorted(accountMap.keys()):
+ info = accountMap[nr]
+ row = {
+ "account": nr,
+ "label": info.get("label", ""),
+ "type": info.get("type", ""),
+ "group": info.get("group", ""),
+ "closingBalance": round(bestClosing.get(nr, 0), 2),
+ }
+ qData = quarterClosing.get(nr, {})
+ for q in ("Q1", "Q2", "Q3", "Q4"):
+ if q in qData:
+ row[q] = round(qData[q], 2)
+ summary.append(row)
+ return summary
+
+
async def refreshAccountingData(self, parameters: Dict[str, Any]) -> ActionResult:
"""Import/refresh accounting data from the configured external system.
@@ -133,7 +179,13 @@ async def refreshAccountingData(self, parameters: Dict[str, Any]) -> ActionResul
def _exportAccountingData(trusteeInterface, featureInstanceId: str, dateFrom: str = None, dateTo: str = None) -> str:
- """Export accounting data (accounts, balances, journal entries+lines) as compact JSON for downstream AI nodes."""
+ """Export accounting data as compact JSON for downstream AI nodes.
+
+ Produces a pre-aggregated ``accountSummary`` (one row per account with
+ a single *Ist* value) so the AI does not have to navigate thousands of
+ raw balance records. Raw per-month balances are deliberately omitted to
+ avoid confusion and reduce payload size.
+ """
from modules.features.trustee.datamodelFeatureTrustee import (
TrusteeDataAccount,
TrusteeDataJournalEntry,
@@ -155,17 +207,9 @@ def _exportAccountingData(trusteeInterface, featureInstanceId: str, dateFrom: st
}
balances = trusteeInterface.db.getRecordset(TrusteeDataAccountBalance, recordFilter=baseFilter) or []
- balanceList = []
- for b in balances:
- balanceList.append({
- "account": b.get("accountNumber", ""),
- "year": b.get("periodYear", 0),
- "month": b.get("periodMonth", 0),
- "opening": b.get("openingBalance", 0),
- "debit": b.get("debitTotal", 0),
- "credit": b.get("creditTotal", 0),
- "closing": b.get("closingBalance", 0),
- })
+
+ currentYear = _dt.now(tz=_tz.utc).year
+ accountSummary = _buildAccountSummary(accountMap, balances, currentYear)
entries = trusteeInterface.db.getRecordset(TrusteeDataJournalEntry, recordFilter=baseFilter) or []
fromTs = _isoToTs(dateFrom)
@@ -205,21 +249,26 @@ def _exportAccountingData(trusteeInterface, featureInstanceId: str, dateFrom: st
})
export = {
- "accounts": list(accountMap.values()),
- "balances": balanceList,
+ "accountSummary": accountSummary,
"journalLines": lineList,
"meta": {
"accountCount": len(accountMap),
"entryCount": len(entryMap),
"lineCount": len(lineList),
- "balanceCount": len(balanceList),
+ "summaryYear": currentYear,
"dateFrom": dateFrom,
"dateTo": dateTo,
+ "hint": (
+ "accountSummary contains ONE row per account with the "
+ "current-year closing balance (Ist). Use this for "
+ "budget comparisons. journalLines lists individual "
+ "bookings for drill-down."
+ ),
},
}
result = json.dumps(export, ensure_ascii=False, default=str)
- logger.info("Exported accounting data: %d accounts, %d entries, %d lines, %d balances (%d bytes)",
- len(accountMap), len(entryMap), len(lineList), len(balanceList), len(result))
+ logger.info("Exported accounting data: %d accounts (summary), %d entries, %d lines (%d bytes)",
+ len(accountSummary), len(entryMap), len(lineList), len(result))
return result
except Exception as e:
logger.warning("Could not export accounting data: %s", e)
From a173fab15ff3d9fde82a44874018c4c919f47e45 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Tue, 19 May 2026 17:42:24 +0200
Subject: [PATCH 5/6] fix mandate res
---
modules/routes/routeHelpers.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/modules/routes/routeHelpers.py b/modules/routes/routeHelpers.py
index bb1386af..b58ffc6d 100644
--- a/modules/routes/routeHelpers.py
+++ b/modules/routes/routeHelpers.py
@@ -41,7 +41,7 @@ def resolveMandateLabels(ids: List[str]) -> Dict[str, Optional[str]]:
m = mMap.get(mid)
label = (getattr(m, "label", None) or getattr(m, "name", None)) if m else None
if not label:
- logger.warning("resolveMandateLabels: no label for id=%s (found=%s)", mid, m is not None)
+ logger.debug("resolveMandateLabels: no label for id=%s (found=%s)", mid, m is not None)
result[mid] = label or None
return result
@@ -57,7 +57,7 @@ def resolveInstanceLabels(ids: List[str]) -> Dict[str, Optional[str]]:
fi = featureIface.getFeatureInstance(iid)
label = fi.label if fi and fi.label else None
if not label:
- logger.warning("resolveInstanceLabels: no label for id=%s (found=%s)", iid, fi is not None)
+ logger.debug("resolveInstanceLabels: no label for id=%s (found=%s)", iid, fi is not None)
result[iid] = label
return result
@@ -104,7 +104,7 @@ def resolveRoleLabels(ids: List[str]) -> Dict[str, Optional[str]]:
out[rid] = r.get("roleLabel") or None
for rid in ids:
if out.get(rid) is None:
- logger.warning("resolveRoleLabels: no label for id=%s", rid)
+ logger.debug("resolveRoleLabels: no label for id=%s", rid)
return out
From 09c6d33deca58e715c8e6256a2c9ed0021eee44d Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Tue, 19 May 2026 22:14:00 +0200
Subject: [PATCH 6/6] fixed expenses workflow
---
.../mainServiceSharepoint.py | 27 +++++++------------
.../methodTrustee/actions/processDocuments.py | 21 ++++++++++++---
2 files changed, 27 insertions(+), 21 deletions(-)
diff --git a/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py b/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py
index 483d7fbe..4fd1fb36 100644
--- a/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py
+++ b/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py
@@ -327,27 +327,20 @@ class SharepointService:
return None
async def uploadFile(self, siteId: str, folderPath: str, fileName: str, content: bytes) -> Dict[str, Any]:
- """Upload a file to SharePoint."""
- try:
- # Clean the path
- cleanPath = folderPath.lstrip('/')
- uploadPath = f"{cleanPath.rstrip('/')}/{fileName}"
- endpoint = f"sites/{siteId}/drive/root:/{uploadPath}:/content"
+ """Upload a file to SharePoint. Raises on failure."""
+ cleanPath = folderPath.lstrip('/')
+ uploadPath = f"{cleanPath.rstrip('/')}/{fileName}"
+ endpoint = f"sites/{siteId}/drive/root:/{uploadPath}:/content"
- logger.info(f"Uploading file to: {endpoint}")
+ logger.info(f"Uploading file to: {endpoint}")
- result = await self._makeGraphApiCall(endpoint, method="PUT", data=content)
+ result = await self._makeGraphApiCall(endpoint, method="PUT", data=content)
- if "error" in result:
- logger.error(f"Upload failed: {result['error']}")
- return result
+ if "error" in result:
+ raise Exception(f"Upload failed: {result['error']}")
- logger.info(f"File uploaded successfully: {fileName}")
- return result
-
- except Exception as e:
- logger.error(f"Error uploading file: {str(e)}")
- return {"error": f"Error uploading file: {str(e)}"}
+ logger.info(f"File uploaded successfully: {fileName}")
+ return result
async def downloadFile(self, siteId: str, fileId: str) -> Optional[bytes]:
"""Download a file from SharePoint."""
diff --git a/modules/workflows/methods/methodTrustee/actions/processDocuments.py b/modules/workflows/methods/methodTrustee/actions/processDocuments.py
index b05e25f4..29d5ab13 100644
--- a/modules/workflows/methods/methodTrustee/actions/processDocuments.py
+++ b/modules/workflows/methods/methodTrustee/actions/processDocuments.py
@@ -247,16 +247,29 @@ def _resolveDocumentList(documentListParam, services) -> List[tuple]:
if isinstance(first, dict) and ("documentData" in first or "documentName" in first):
for doc in documentListParam:
rawData = doc.get("documentData")
- logger.debug("_resolveDocumentList: doc keys=%s documentData type=%s documentData truthy=%s", list(doc.keys()), type(rawData).__name__, bool(rawData))
+ fileId = (doc.get("validationMetadata") or {}).get("fileId") or doc.get("fileId", "")
+ fileName = doc.get("documentName") or doc.get("fileName") or "document"
+ mimeType = doc.get("mimeType") or doc.get("documentMimeType") or "application/json"
+
+ # When documentData was persisted as binary (_hasBinaryData), read it
+ # back from file storage via the chat service.
+ if not rawData and doc.get("_hasBinaryData") and fileId:
+ chatService = getattr(services, "chat", None)
+ if chatService:
+ try:
+ rawBytes = chatService.getFileData(fileId)
+ if rawBytes:
+ rawData = rawBytes.decode("utf-8") if isinstance(rawBytes, bytes) else rawBytes
+ except Exception as e:
+ logger.debug("_resolveDocumentList: failed to read binary for fileId=%s: %s", fileId, e)
+
+ logger.debug("_resolveDocumentList: doc keys=%s documentData type=%s documentData truthy=%s", list(doc.keys()), type(rawData).__name__ if rawData else "NoneType", bool(rawData))
if not rawData:
continue
try:
data = json.loads(rawData) if isinstance(rawData, str) else rawData
except (json.JSONDecodeError, TypeError):
continue
- fileId = (doc.get("validationMetadata") or {}).get("fileId") or doc.get("fileId", "")
- fileName = doc.get("documentName") or doc.get("fileName") or "document"
- mimeType = doc.get("mimeType") or doc.get("documentMimeType") or "application/json"
results.append((data, fileId, fileName, mimeType))
if results:
return results