fix: performance improvements

- app.py: Pre-warm AI connectors at module load and in lifespan
- aicoreModelRegistry.py: Connector discovery cache, getAvailableModels cache, bulk RBAC, eager prewarm
- connectorDbPostgre.py: Connector cache, contextvars for userId, eviction (max 32)
- chatbot: Uses _get_cached_connector, Service center integration, BillingService exceptions instead of direct imports
- interfaceDbApp.py: Uses _get_cached_connector
- interfaceDbManagement.py: Uses _get_cached_connector
- security/rbac.py: Adds checkResourceAccessBulk
This commit is contained in:
Ida Dittrich 2026-03-06 13:46:54 +01:00
parent 42e79a724a
commit 6dc2afafb9
13 changed files with 1317 additions and 192 deletions

19
app.py
View file

@ -280,12 +280,29 @@ initLogging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL") instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
# Pre-warm AI connectors on process load (before lifespan). Critical for chatbot latency.
# The import is executed purely for its import-time side effect; a failure is
# logged and otherwise ignored so the application can still start.
try:
    import modules.aicore.aicoreModelRegistry  # noqa: F401
except Exception as e:
    logger.warning(f"AI pre-warm at app load failed: {e}")
else:
    logger.info("AI connectors pre-warm (app load) triggered")
# Define lifespan context manager for application startup/shutdown events # Define lifespan context manager for application startup/shutdown events
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
logger.info("Application is starting up") logger.info("Application is starting up")
# --- Pre-warm AI connectors FIRST (before any other startup work) ---
# Avoids 48 s latency on first chatbot request; must run before first use.
try:
import modules.aicore.aicoreModelRegistry # noqa: F401 - triggers eager pre-warm
from modules.aicore.aicoreModelRegistry import modelRegistry
modelRegistry.ensureConnectorsRegistered()
modelRegistry.refreshModels(force=True)
logger.info("AI connectors and model registry pre-warmed")
except Exception as e:
logger.warning(f"AI pre-warm failed: {e}")
# Bootstrap database if needed (creates initial users, mandates, roles, etc.) # Bootstrap database if needed (creates initial users, mandates, roles, etc.)
# This must happen before getting root interface # This must happen before getting root interface
from modules.security.rootAccess import getRootDbAppConnector from modules.security.rootAccess import getRootDbAppConnector
@ -333,7 +350,7 @@ async def lifespan(app: FastAPI):
# Register audit log cleanup scheduler # Register audit log cleanup scheduler
from modules.shared.auditLogger import registerAuditLogCleanupScheduler from modules.shared.auditLogger import registerAuditLogCleanupScheduler
registerAuditLogCleanupScheduler() registerAuditLogCleanupScheduler()
# Ensure billing settings and accounts exist for all mandates # Ensure billing settings and accounts exist for all mandates
try: try:
from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface

View file

@ -8,7 +8,8 @@ Implements plugin-like architecture for connector discovery.
import logging import logging
import importlib import importlib
import os import os
from typing import Dict, List, Optional, Any import time
from typing import Dict, List, Optional, Any, Tuple
from modules.datamodels.datamodelAi import AiModel from modules.datamodels.datamodelAi import AiModel
from .aicoreBase import BaseConnectorAi from .aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelUam import User from modules.datamodels.datamodelUam import User
@ -31,6 +32,9 @@ class ModelRegistry:
self._lastRefresh: Optional[float] = None self._lastRefresh: Optional[float] = None
self._refreshInterval: float = 300.0 # 5 minutes self._refreshInterval: float = 300.0 # 5 minutes
self._connectorsInitialized: bool = False self._connectorsInitialized: bool = False
self._discoveredConnectorsCache: Optional[List[BaseConnectorAi]] = None # Avoid re-instantiating on every discoverConnectors() call
self._getAvailableModelsCache: Dict[Tuple[str, int], Tuple[List[AiModel], float]] = {} # (user_id, rbac_id) -> (models, ts)
self._getAvailableModelsCacheTtl: float = 30.0 # seconds
def registerConnector(self, connector: BaseConnectorAi): def registerConnector(self, connector: BaseConnectorAi):
"""Register a connector and collect its models.""" """Register a connector and collect its models."""
@ -68,34 +72,38 @@ class ModelRegistry:
raise raise
def discoverConnectors(self) -> List[BaseConnectorAi]: def discoverConnectors(self) -> List[BaseConnectorAi]:
"""Auto-discover connectors by scanning aicorePlugin*.py files.""" """Auto-discover connectors by scanning aicorePlugin*.py files. Cached after first call to avoid 4-8 s re-init on every use."""
if self._discoveredConnectorsCache is not None:
return self._discoveredConnectorsCache
connectors = [] connectors = []
connectorDir = os.path.dirname(__file__) connectorDir = os.path.dirname(__file__)
# Scan for connector files # Scan for connector files
for filename in os.listdir(connectorDir): for filename in os.listdir(connectorDir):
if filename.startswith('aicorePlugin') and filename.endswith('.py'): if filename.startswith('aicorePlugin') and filename.endswith('.py'):
moduleName = filename[:-3] # Remove .py extension moduleName = filename[:-3] # Remove .py extension
try: try:
# Import the module # Import the module
module = importlib.import_module(f'modules.aicore.{moduleName}') module = importlib.import_module(f'modules.aicore.{moduleName}')
# Find connector classes (classes that inherit from BaseConnectorAi) # Find connector classes (classes that inherit from BaseConnectorAi)
for attrName in dir(module): for attrName in dir(module):
attr = getattr(module, attrName) attr = getattr(module, attrName)
if (isinstance(attr, type) and if (isinstance(attr, type) and
issubclass(attr, BaseConnectorAi) and issubclass(attr, BaseConnectorAi) and
attr != BaseConnectorAi): attr != BaseConnectorAi):
# Instantiate the connector # Instantiate the connector
connector = attr() connector = attr()
connectors.append(connector) connectors.append(connector)
logger.info(f"Discovered connector: {connector.getConnectorType()}") logger.info(f"Discovered connector: {connector.getConnectorType()}")
except Exception as e: except Exception as e:
logger.warning(f"Failed to discover connector from {filename}: {e}") logger.warning(f"Failed to discover connector from {filename}: {e}")
self._discoveredConnectorsCache = connectors
return connectors return connectors
def ensureConnectorsRegistered(self): def ensureConnectorsRegistered(self):
@ -175,24 +183,49 @@ class ModelRegistry:
self.refreshModels() self.refreshModels()
return [model for model in self._models.values() if model.priority == priority] return [model for model in self._models.values() if model.priority == priority]
def getAvailableModels(self, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> List[AiModel]: def getAvailableModels(
self,
currentUser: Optional[User] = None,
rbacInstance: Optional[RbacClass] = None,
mandateId: Optional[str] = None,
featureInstanceId: Optional[str] = None
) -> List[AiModel]:
"""Get only available models, optionally filtered by RBAC permissions. """Get only available models, optionally filtered by RBAC permissions.
Results are cached per (user, rbac) for 30s to avoid repeated filtering on each LLM call.
Args: Args:
currentUser: Optional user object for RBAC filtering currentUser: Optional user object for RBAC filtering
rbacInstance: Optional RBAC instance for permission checks rbacInstance: Optional RBAC instance for permission checks
mandateId: Optional mandate context for faster RBAC (loads fewer roles)
featureInstanceId: Optional feature instance for RBAC context
Returns: Returns:
List of available models (filtered by RBAC if user provided) List of available models (filtered by RBAC if user provided)
""" """
self.refreshModels() self.refreshModels()
cache_key = (currentUser.id if currentUser else "", id(rbacInstance) if rbacInstance else 0)
now = time.time()
if cache_key in self._getAvailableModelsCache:
cached_models, cached_ts = self._getAvailableModelsCache[cache_key]
if now - cached_ts < self._getAvailableModelsCacheTtl:
logger.debug(f"getAvailableModels: cache hit for user={cache_key[0][:8] if cache_key[0] else 'anon'}...")
return cached_models
allModels = list(self._models.values()) allModels = list(self._models.values())
availableModels = [model for model in allModels if model.isAvailable] availableModels = [model for model in allModels if model.isAvailable]
# Apply RBAC filtering if user and RBAC instance provided # Apply RBAC filtering if user and RBAC instance provided (batch check for performance)
if currentUser and rbacInstance: if currentUser and rbacInstance:
availableModels = self._filterModelsByRbac(availableModels, currentUser, rbacInstance) availableModels = self._filterModelsByRbac(
availableModels, currentUser, rbacInstance, mandateId, featureInstanceId
)
self._getAvailableModelsCache[cache_key] = (availableModels, now)
# Prune expired entries to avoid unbounded growth
expired = [k for k, (_, ts) in self._getAvailableModelsCache.items() if now - ts >= self._getAvailableModelsCacheTtl]
for k in expired:
del self._getAvailableModelsCache[k]
unavailableCount = len(allModels) - len(availableModels) unavailableCount = len(allModels) - len(availableModels)
if unavailableCount > 0: if unavailableCount > 0:
unavailableModels = [m.name for m in allModels if not m.isAvailable] unavailableModels = [m.name for m in allModels if not m.isAvailable]
@ -200,32 +233,33 @@ class ModelRegistry:
logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}") logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
return availableModels return availableModels
def _filterModelsByRbac(self, models: List[AiModel], currentUser: User, rbacInstance: RbacClass) -> List[AiModel]: def _filterModelsByRbac(
"""Filter models based on RBAC permissions. self,
models: List[AiModel],
Args: currentUser: User,
models: List of models to filter rbacInstance: RbacClass,
currentUser: Current user object mandateId: Optional[str] = None,
rbacInstance: RBAC instance for permission checks featureInstanceId: Optional[str] = None
) -> List[AiModel]:
Returns: """Filter models based on RBAC permissions. Uses bulk check for performance."""
Filtered list of models that user has access to paths = []
""" model_paths = {} # model -> (connector_path, model_path)
for model in models:
connector_path = f"ai.model.{model.connectorType}"
model_path = f"ai.model.{model.connectorType}.{model.displayName}"
paths.extend([connector_path, model_path])
model_paths[id(model)] = (connector_path, model_path)
# Single bulk RBAC call instead of 2*N per-model calls
access = rbacInstance.checkResourceAccessBulk(
currentUser, list(dict.fromkeys(paths)), mandateId, featureInstanceId
)
filteredModels = [] filteredModels = []
for model in models: for model in models:
# Check access at both connector level and model level connector_path, model_path = model_paths[id(model)]
connectorResourcePath = f"ai.model.{model.connectorType}" if access.get(connector_path, False) or access.get(model_path, False):
modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"
# User needs access to either connector (all models) or specific model
hasConnectorAccess = checkResourceAccess(rbacInstance, currentUser, connectorResourcePath)
hasModelAccess = checkResourceAccess(rbacInstance, currentUser, modelResourcePath)
if hasConnectorAccess or hasModelAccess:
filteredModels.append(model) filteredModels.append(model)
else: else:
logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})") logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")
return filteredModels return filteredModels
def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]: def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]:
@ -305,3 +339,17 @@ class ModelRegistry:
# Global registry instance # Global registry instance
modelRegistry = ModelRegistry() modelRegistry = ModelRegistry()
# Eager pre-warm on first import: ensures connectors are ready in this process.
# Critical for chatbot performance — avoids 48 s latency on first request.
# Runs when this module is first imported (lifespan or first chatbot request).
def _eager_prewarm() -> None:
    """Register connectors and force a model refresh on the global registry.

    Best-effort: any exception is logged as a warning and swallowed, so a
    pre-warm problem never makes importing this module fail.
    """
    try:
        modelRegistry.ensureConnectorsRegistered()
        modelRegistry.refreshModels(force=True)
    except Exception as e:
        logger.warning(f"AI eager pre-warm skipped: {e}")
    else:
        logger.info("AI connectors and model registry pre-warmed (module load)")


# Import-time side effect: runs once, when this module is first imported.
_eager_prewarm()

View file

@ -1,5 +1,6 @@
# Copyright (c) 2025 Patrick Motsch # Copyright (c) 2025 Patrick Motsch
# All rights reserved. # All rights reserved.
import contextvars
import psycopg2 import psycopg2
import psycopg2.extras import psycopg2.extras
import logging import logging
@ -99,7 +100,56 @@ def _get_model_fields(model_class) -> Dict[str, str]:
return fields return fields
# No caching needed with proper database # Cache connectors by (host, database, port) to avoid duplicate inits for same database.
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
# contextvars to avoid races when concurrent requests share the same connector.
# Upper bound on cached connectors; _get_cached_connector() evicts the oldest entry beyond this.
_MAX_CACHED_CONNECTORS = 32
# Live connectors, keyed by the tuple built in _get_cached_connector().
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
_connector_cache_order: List[tuple] = [] # FIFO order for eviction
# Guards both _connector_cache and _connector_cache_order.
_connector_cache_lock = threading.Lock()
# Request-scoped user id; record-saving code reads this first and only falls back
# to the connector's own userId when it is None.
_current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"db_connector_user_id", default=None
)
def _get_cached_connector(
    dbHost: str,
    dbDatabase: str,
    dbUser: str = None,
    dbPassword: str = None,
    dbPort: int = None,
    userId: str = None,
) -> "DatabaseConnector":
    """Return a shared DatabaseConnector for the given database, creating it on first use.

    Connectors are cached per (host, database, port, dbUser). The database
    user is part of the key so that a caller supplying different credentials
    never receives a connector that was authenticated as another user.

    Args:
        dbHost: Database host.
        dbDatabase: Database name.
        dbUser: Database user; also part of the cache key.
        dbPassword: Database password (only used when a connector is created).
        dbPort: Database port; ``None`` is treated as 5432 for the cache key.
        userId: Application user id. When not ``None`` it is stored in the
            request-scoped context variable instead of on the shared connector.

    Returns:
        The cached (or newly created) connector.
    """
    port = int(dbPort) if dbPort is not None else 5432
    key = (dbHost, dbDatabase, port, dbUser)

    with _connector_cache_lock:
        conn = _connector_cache.get(key)
        if conn is None:
            # Evict in FIFO order until there is room for the new connector.
            while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
                oldest_key = _connector_cache_order.pop(0)
                evicted = _connector_cache.pop(oldest_key, None)
                if evicted is None:
                    continue
                try:
                    evicted.close()
                except Exception as e:
                    logger.warning(f"Error closing evicted connector: {e}")

            conn = DatabaseConnector(
                dbHost=dbHost,
                dbDatabase=dbDatabase,
                dbUser=dbUser,
                dbPassword=dbPassword,
                dbPort=dbPort,
                userId=userId,
            )
            _connector_cache[key] = conn
            _connector_cache_order.append(key)

    # Set request-scoped userId via contextvar (avoids mutating the shared connector).
    if userId is not None:
        _current_user_id.set(userId)
    return conn
class DatabaseConnector: class DatabaseConnector:
@ -645,24 +695,22 @@ class DatabaseConnector:
if "id" in record and str(record["id"]) != recordId: if "id" in record and str(record["id"]) != recordId:
raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}") raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}")
# Add metadata # Add metadata - use contextvar for request-scoped userId when sharing connector
effective_user_id = _current_user_id.get()
if effective_user_id is None:
effective_user_id = self.userId
currentTime = getUtcTimestamp() currentTime = getUtcTimestamp()
# Set _createdAt and _createdBy if this is a new record (record doesn't have _createdAt) # Set _createdAt and _createdBy if this is a new record (record doesn't have _createdAt)
if "_createdAt" not in record: if "_createdAt" not in record:
record["_createdAt"] = currentTime record["_createdAt"] = currentTime
# Only set _createdBy if userId is valid (not None or empty string) if effective_user_id:
if self.userId: record["_createdBy"] = effective_user_id
record["_createdBy"] = self.userId
# No warning - empty userId is normal during bootstrap
# Also ensure _createdBy is set even if _createdAt exists but _createdBy is missing/empty
elif "_createdBy" not in record or not record.get("_createdBy"): elif "_createdBy" not in record or not record.get("_createdBy"):
if self.userId: if effective_user_id:
record["_createdBy"] = self.userId record["_createdBy"] = effective_user_id
# No warning - empty userId is normal during bootstrap
# Always update modification metadata
record["_modifiedAt"] = currentTime record["_modifiedAt"] = currentTime
if self.userId: if effective_user_id:
record["_modifiedBy"] = self.userId record["_modifiedBy"] = effective_user_id
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:
self._save_record(cursor, table, recordId, record, model_class) self._save_record(cursor, table, recordId, record, model_class)
@ -782,12 +830,13 @@ class DatabaseConnector:
return False return False
def updateContext(self, userId: str) -> None: def updateContext(self, userId: str) -> None:
"""Updates the context of the database connector.""" """Updates the context of the database connector.
Sets both instance userId and contextvar for request-scoped use when connector is shared.
"""
if userId is None: if userId is None:
raise ValueError("userId must be provided") raise ValueError("userId must be provided")
self.userId = userId self.userId = userId
# No cache to clear - database handles data consistency _current_user_id.set(userId)
# Public API # Public API

View file

@ -38,9 +38,10 @@ from modules.datamodels.datamodelUam import User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Workflow-level store for allowed_providers (survives LangGraph/bind_tools execution context # Workflow-level store for allowed_providers and RBAC context (survives LangGraph/bind_tools
# where instance attributes may be lost when model is wrapped or serialized) # execution context where instance attributes may be lost when model is wrapped or serialized)
_workflow_allowed_providers: Dict[str, List[str]] = {} _workflow_allowed_providers: Dict[str, List[str]] = {}
_workflow_rbac_context: Dict[str, tuple] = {} # workflow_id -> (mandateId, featureInstanceId)
def clear_workflow_allowed_providers(workflow_id: str) -> None: def clear_workflow_allowed_providers(workflow_id: str) -> None:
@ -62,11 +63,14 @@ class AICenterChatModel(BaseChatModel):
billing_callback: Optional[Callable[[AiCallResponse], None]] = None, billing_callback: Optional[Callable[[AiCallResponse], None]] = None,
workflow_id: Optional[str] = None, workflow_id: Optional[str] = None,
allowed_providers: Optional[List[str]] = None, allowed_providers: Optional[List[str]] = None,
prefer_fast_model: bool = False,
mandate_id: Optional[str] = None,
feature_instance_id: Optional[str] = None,
**kwargs **kwargs
): ):
""" """
Initialize the AI center chat model bridge. Initialize the AI center chat model bridge.
Args: Args:
user: Current user for RBAC and model selection user: Current user for RBAC and model selection
operation_type: Operation type for model selection operation_type: Operation type for model selection
@ -74,6 +78,7 @@ class AICenterChatModel(BaseChatModel):
billing_callback: Optional callback invoked after each _agenerate with AiCallResponse for billing billing_callback: Optional callback invoked after each _agenerate with AiCallResponse for billing
workflow_id: Optional workflow/conversation ID for billing context workflow_id: Optional workflow/conversation ID for billing context
allowed_providers: Optional list of allowed provider connector types (empty/None = all) allowed_providers: Optional list of allowed provider connector types (empty/None = all)
prefer_fast_model: When True, strongly prefer faster models (e.g. gpt-4o-mini for planner)
**kwargs: Additional arguments passed to BaseChatModel **kwargs: Additional arguments passed to BaseChatModel
""" """
super().__init__(**kwargs) super().__init__(**kwargs)
@ -85,9 +90,14 @@ class AICenterChatModel(BaseChatModel):
object.__setattr__(self, "_billing_callback", billing_callback) object.__setattr__(self, "_billing_callback", billing_callback)
object.__setattr__(self, "_workflow_id", workflow_id) object.__setattr__(self, "_workflow_id", workflow_id)
object.__setattr__(self, "_allowed_providers", allowed_providers or []) object.__setattr__(self, "_allowed_providers", allowed_providers or [])
object.__setattr__(self, "_prefer_fast_model", prefer_fast_model)
object.__setattr__(self, "_mandate_id", mandate_id)
object.__setattr__(self, "_feature_instance_id", feature_instance_id)
# Store in workflow-level registry so it survives when instance attrs are lost (e.g. bind_tools) # Store in workflow-level registry so it survives when instance attrs are lost (e.g. bind_tools)
if workflow_id and allowed_providers: if workflow_id and allowed_providers:
_workflow_allowed_providers[workflow_id] = list(allowed_providers) _workflow_allowed_providers[workflow_id] = list(allowed_providers)
if workflow_id and (mandate_id is not None or feature_instance_id is not None):
_workflow_rbac_context[workflow_id] = (mandate_id, feature_instance_id)
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
@ -129,17 +139,25 @@ class AICenterChatModel(BaseChatModel):
# Get available models with RBAC filtering # Get available models with RBAC filtering
# Use cached/singleton interfaces for better performance # Use cached/singleton interfaces for better performance
from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbApp import getRootInterface
workflow_id = getattr(self, "_workflow_id", None)
rootInterface = getRootInterface() rootInterface = getRootInterface()
rbac_instance = rootInterface.rbac rbac_instance = rootInterface.rbac
mandate_id = getattr(self, "_mandate_id", None)
feature_instance_id = getattr(self, "_feature_instance_id", None)
if workflow_id and (mandate_id is None and feature_instance_id is None):
ctx = _workflow_rbac_context.get(workflow_id)
if ctx:
mandate_id, feature_instance_id = ctx
available_models = modelRegistry.getAvailableModels( available_models = modelRegistry.getAvailableModels(
currentUser=self.user, currentUser=self.user,
rbacInstance=rbac_instance rbacInstance=rbac_instance,
mandateId=mandate_id,
featureInstanceId=feature_instance_id,
) )
# Allowed providers: instance attr or workflow store (lost in LangGraph/bind_tools context) # Allowed providers: instance attr or workflow store (lost in LangGraph/bind_tools context)
workflow_id = getattr(self, '_workflow_id', None)
allowed = ( allowed = (
(_workflow_allowed_providers.get(workflow_id) if workflow_id else None) (_workflow_allowed_providers.get(workflow_id) if workflow_id else None)
or getattr(self, '_allowed_providers', None) or getattr(self, '_allowed_providers', None)
@ -155,7 +173,8 @@ class AICenterChatModel(BaseChatModel):
options = AiCallOptions( options = AiCallOptions(
operationType=self.operation_type, operationType=self.operation_type,
processingMode=self.processing_mode, processingMode=self.processing_mode,
allowedProviders=allowed if allowed else None allowedProviders=allowed if allowed else None,
preferFastModel=getattr(self, "_prefer_fast_model", False),
) )
# Select model # Select model
@ -246,7 +265,97 @@ class AICenterChatModel(BaseChatModel):
# Run the async method synchronously # Run the async method synchronously
return asyncio.run(self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs)) return asyncio.run(self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs))
async def _call_openai_streaming(
    self,
    ai_messages: List[dict],
    run_manager: Optional[Any],
    model_call: "AiModelCall",
    input_bytes: int,
    start_time: float,
) -> "AiModelResponse":
    """Call the selected OpenAI-compatible endpoint with ``stream=True``.

    Each content token is forwarded to ``run_manager.on_llm_new_token`` (when
    a run manager is given) and the concatenated text is returned. After the
    stream ends, the billing callback (if any) is invoked with the totals.

    Args:
        ai_messages: Chat messages sent as the ``messages`` payload field.
        run_manager: Optional callback manager; may be sync or async.
        model_call: Not used by this method; kept for signature compatibility.
        input_bytes: Request size passed on to pricing and billing.
        start_time: ``time.time()`` value from which processing time is measured.

    Returns:
        AiModelResponse with the full streamed content.

    Raises:
        ValueError: If the OpenAI API key is missing or the endpoint does not
            answer with HTTP 200.
    """
    import inspect
    import httpx
    import json as _json
    from modules.shared.configuration import APP_CONFIG

    model = self._selected_model
    if model.connectorType == "openai":
        api_url = getattr(model, "apiUrl", None) or "https://api.openai.com/v1/chat/completions"
        api_key = APP_CONFIG.get("Connector_AiOpenai_API_SECRET")
        if not api_key:
            raise ValueError("OpenAI API key not configured")
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        model_name = model.name
    else:
        # `or ""` also covers an apiUrl attribute that exists but is None.
        base_url = (getattr(model, "apiUrl", None) or "").replace("/api/analyze", "")
        api_url = f"{base_url.rstrip('/')}/v1/chat/completions"
        api_key = APP_CONFIG.get("Connector_AiPrivateLlm_API_SECRET")
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["X-API-Key"] = api_key
        model_name = getattr(model, "version", None) or model.name

    payload = {
        "model": model_name,
        "messages": ai_messages,
        "temperature": model.temperature,
        "max_tokens": model.maxTokens,
        "stream": True,
    }

    on_token = getattr(run_manager, "on_llm_new_token", None) if run_manager else None
    content_parts: List[str] = []
    async with httpx.AsyncClient(timeout=600.0) as client:
        async with client.stream("POST", api_url, headers=headers, json=payload) as resp:
            if resp.status_code != 200:
                body = (await resp.aread()).decode("utf-8", errors="replace")
                raise ValueError(f"OpenAI stream error: {resp.status_code} - {body}")
            # aiter_lines() does the line splitting, including a final line
            # that is not terminated by a newline.
            async for raw_line in resp.aiter_lines():
                line = raw_line.strip()
                if not line.startswith("data:"):
                    continue
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    # End-of-stream sentinel: stop reading altogether.
                    break
                try:
                    data = _json.loads(data_str)
                except _json.JSONDecodeError:
                    # Ignore lines that are not valid JSON.
                    continue
                if not isinstance(data, dict):
                    continue
                choices = data.get("choices") or []
                if not choices:
                    continue
                delta = choices[0].get("delta") or {}
                token = delta.get("content") or ""
                if not token:
                    continue
                if on_token is not None:
                    emitted = on_token(token)
                    # An async run manager returns a coroutine here; without
                    # awaiting it the token callback would never actually run.
                    if inspect.isawaitable(emitted):
                        await emitted
                content_parts.append(token)

    content = "".join(content_parts)
    processing_time = time.time() - start_time
    output_bytes = len(content.encode("utf-8"))

    price_chf = 0.0
    calc_price = getattr(model, "calculatepriceCHF", None)
    if calc_price:
        try:
            price_chf = calc_price(processing_time, input_bytes, output_bytes)
        except Exception as e:
            # Best-effort pricing: keep 0.0 but leave a trace instead of hiding the failure.
            logger.warning(f"Price calculation failed, using 0.0: {e}")

    billing_callback = getattr(self, "_billing_callback", None)
    if billing_callback:
        try:
            billing_callback(AiCallResponse(
                content=content,
                modelName=model.name,
                provider=model.connectorType or "unknown",
                priceCHF=price_chf,
                processingTime=processing_time,
                bytesSent=input_bytes,
                bytesReceived=output_bytes,
                errorCount=0,
            ))
        except Exception as e:
            logger.error(f"Billing callback error: {e}")

    return AiModelResponse(content=content, success=True, modelId=model.name, metadata={})
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: List[BaseMessage],
@ -484,6 +593,11 @@ class AICenterChatModel(BaseChatModel):
"tool_calls": tool_calls, "tool_calls": tool_calls,
}, },
) )
elif not tools and self._selected_model.connectorType in ("openai", "privatellm"):
# Streaming path for OpenAI/Ollama without tools (ChatGPT-like token streaming)
response = await self._call_openai_streaming(
ai_messages, run_manager, model_call, input_bytes, start_time
)
else: else:
# No tools or not OpenAI - use connector normally # No tools or not OpenAI - use connector normally
if not self._selected_model.functionCall: if not self._selected_model.functionCall:

View file

@ -5,6 +5,7 @@ Custom LangGraph checkpointer using existing database interface.
Maps LangGraph state to existing message storage format. Maps LangGraph state to existing message storage format.
""" """
import contextvars
import logging import logging
import uuid import uuid
from typing import Any, Dict, List, Optional, Tuple, NamedTuple from typing import Any, Dict, List, Optional, Tuple, NamedTuple
@ -47,19 +48,30 @@ class DatabaseCheckpointer(BaseCheckpointSaver):
Maps LangGraph thread_id to conversation.id; stores messages via interface (workflowId maps to conversationId). Maps LangGraph thread_id to conversation.id; stores messages via interface (workflowId maps to conversationId).
""" """
def __init__(self, user: User, workflow_id: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None): def __init__(
self,
user: User,
workflow_id: str,
mandateId: Optional[str] = None,
featureInstanceId: Optional[str] = None,
*,
interface=None,
):
""" """
Initialize the database checkpointer. Initialize the database checkpointer.
Args: Args:
user: Current user for database access user: Current user for database access
workflow_id: Workflow ID (maps to LangGraph thread_id) workflow_id: Workflow ID (maps to LangGraph thread_id)
mandateId: Mandate ID for proper data isolation mandateId: Mandate ID for proper data isolation
featureInstanceId: Feature instance ID for proper data isolation featureInstanceId: Feature instance ID for proper data isolation
interface: Optional pre-created chatbot interface (avoids extra getInterface + DB init)
""" """
self.user = user self.user = user
self.workflow_id = workflow_id self.workflow_id = workflow_id
self.interface = getChatbotInterface(user, mandateId=mandateId, featureInstanceId=featureInstanceId) self.interface = interface if interface is not None else getChatbotInterface(
user, mandateId=mandateId, featureInstanceId=featureInstanceId
)
def _convert_langchain_to_db_message( def _convert_langchain_to_db_message(
self, self,
@ -445,3 +457,120 @@ class DatabaseCheckpointer(BaseCheckpointSaver):
# Not implemented - using aput() instead # Not implemented - using aput() instead
# This method is called by LangGraph but we handle writes through aput() # This method is called by LangGraph but we handle writes through aput()
pass pass
# ContextVar holding the per-request checkpointer. CheckpointerResolver reads it at
# call time, so a cached compiled graph can be invoked with a different checkpointer
# per request. Defaults to None, which CheckpointerResolver treats as an error.
_current_checkpointer: contextvars.ContextVar[Optional[BaseCheckpointSaver]] = contextvars.ContextVar(
"chatbot_current_checkpointer", default=None
)
def set_checkpointer(checkpointer: BaseCheckpointSaver) -> contextvars.Token:
    """Bind *checkpointer* to the current context.

    Returns:
        The token to hand to reset_checkpointer() to restore the prior value.
    """
    token = _current_checkpointer.set(checkpointer)
    return token
def reset_checkpointer(token: contextvars.Token) -> None:
    """Restore the checkpointer that was active before the matching set_checkpointer().

    A ValueError from the reset (token created in a different context, e.g.
    during generator cleanup after a yield) is ignored, so the call is safe
    from any context.
    """
    try:
        _current_checkpointer.reset(token)
    except ValueError:
        # Token belongs to another context; nothing to restore here.
        return
class CheckpointerResolver(BaseCheckpointSaver):
    """
    Delegating checkpointer that reads the real checkpointer from context.

    Used for graph caching: the compiled graph uses this resolver; at invoke time
    the per-request checkpointer is set via set_checkpointer(). Every method
    forwards to that checkpointer and passes its return value through. The async
    methods fall back to the synchronous counterpart when the inner checkpointer
    has no async variant.
    """

    def _get_checkpointer(self) -> BaseCheckpointSaver:
        """Return the checkpointer bound to the current context.

        Raises:
            RuntimeError: If set_checkpointer() has not been called in this context.
        """
        cp = _current_checkpointer.get()
        if cp is None:
            raise RuntimeError(
                "CheckpointerResolver: no checkpointer in context. "
                "Call set_checkpointer() before invoking the cached graph."
            )
        return cp

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> Any:
        """Store a checkpoint via the current checkpointer; returns whatever it returns."""
        return self._get_checkpointer().put(config, checkpoint, metadata, new_versions)

    def get(self, config: Dict[str, Any]) -> Optional[Checkpoint]:
        """Fetch the checkpoint for *config* from the current checkpointer."""
        return self._get_checkpointer().get(config)

    def list(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """List checkpoints via the current checkpointer."""
        return self._get_checkpointer().list(config, filter, before, limit)

    def put_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> Any:
        """Store intermediate writes via the current checkpointer."""
        return self._get_checkpointer().put_writes(config, writes, task_id)

    async def aget_tuple(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """Return the checkpoint tuple for *config*.

        Falls back to wrapping the result of the synchronous get() in a
        CheckpointTuple when the inner checkpointer has no aget_tuple().
        """
        inner = self._get_checkpointer()
        if hasattr(inner, "aget_tuple"):
            return await inner.aget_tuple(config)
        checkpoint = inner.get(config)
        if not checkpoint:
            return None
        metadata: CheckpointMetadata = {"step": 0}
        return CheckpointTuple(
            config=config,
            checkpoint=checkpoint,
            metadata=metadata,
            parent_config=None,
            pending_writes=None,
        )

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> Any:
        """Async put(); falls back to the synchronous put() when needed."""
        inner = self._get_checkpointer()
        if hasattr(inner, "aput"):
            return await inner.aput(config, checkpoint, metadata, new_versions)
        return inner.put(config, checkpoint, metadata, new_versions)

    async def alist(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """Async list(); falls back to the synchronous list() when needed."""
        inner = self._get_checkpointer()
        if hasattr(inner, "alist"):
            return await inner.alist(config, filter, before, limit)
        return inner.list(config, filter, before, limit)

    async def aput_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> Any:
        """Async put_writes(); falls back to the synchronous put_writes().

        The fallback keeps writes from being silently dropped when the inner
        checkpointer only implements the synchronous method.
        """
        inner = self._get_checkpointer()
        if hasattr(inner, "aput_writes"):
            return await inner.aput_writes(config, writes, task_id)
        return inner.put_writes(config, writes, task_id)

View file

@ -2,8 +2,10 @@
# All rights reserved. # All rights reserved.
"""Chatbot domain logic.""" """Chatbot domain logic."""
import contextvars
import re import re
import logging import logging
import threading
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Annotated, AsyncIterator, Any, List, Optional, TYPE_CHECKING from typing import Annotated, AsyncIterator, Any, List, Optional, TYPE_CHECKING
from pydantic import BaseModel from pydantic import BaseModel
@ -21,7 +23,12 @@ from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from modules.features.chatbot.bridges.ai import AICenterChatModel from modules.features.chatbot.bridges.ai import AICenterChatModel
from modules.features.chatbot.bridges.memory import DatabaseCheckpointer from modules.features.chatbot.bridges.memory import (
CheckpointerResolver,
DatabaseCheckpointer,
set_checkpointer,
reset_checkpointer,
)
from modules.features.chatbot.bridges.tools import ( from modules.features.chatbot.bridges.tools import (
create_sql_query_tool, create_sql_query_tool,
create_tavily_search_tool, create_tavily_search_tool,
@ -168,13 +175,412 @@ class ChatState(BaseModel):
plan: Optional[str] = None # Planner routing: "SQL", "TAVILY", "BOTH", "NONE" plan: Optional[str] = None # Planner routing: "SQL", "TAVILY", "BOTH", "NONE"
@dataclass
class ChatbotGraphContext:
    """Request-scoped bundle of everything the cached graph's nodes need.

    The compiled graph is shared, so its nodes fetch models, tools and prompt
    material from this object (via the graph-context ContextVar) at run time.
    """

    # --- models ---
    model: AICenterChatModel  # main model used by the agent nodes
    planner_model: AICenterChatModel  # model used for planning / plain answers

    # --- tools ---
    tools: List[Any]  # all configured tools
    tools_by_name: dict  # tool name -> tool, used to dispatch tool calls
    sql_tool: Any  # "sqlite_query" tool, or None when not configured
    tavily_tool: Any  # "tavily_search" tool, or None when not configured
    streaming_tool: Any  # "send_streaming_message" tool, or None when not configured

    # --- prompt ---
    prompt_sections: dict  # sections of the system prompt (e.g. "intro", "schema", "response_structure")
    system_prompt: str  # the full, unsplit system prompt
# Context-local slot for the ChatbotGraphContext of the request being processed.
_graph_context: contextvars.ContextVar[Optional[ChatbotGraphContext]] = (
    contextvars.ContextVar("chatbot_graph_context", default=None)
)
def _get_graph_context() -> ChatbotGraphContext:
    """Return the graph context bound to the current context.

    Raises:
        RuntimeError: If no graph context has been set.
    """
    current = _graph_context.get()
    if current is not None:
        return current
    raise RuntimeError(
        "ChatbotGraphContext not set. Ensure graph context is set before invoking cached graph."
    )
def _set_graph_context(ctx: ChatbotGraphContext) -> contextvars.Token:
    """Bind ``ctx`` to the current context and return the token needed to undo it."""
    token = _graph_context.set(ctx)
    return token
def _reset_graph_context(token: contextvars.Token) -> None:
    """Restore the graph context that was active before _set_graph_context().

    ContextVar.reset() raises ValueError for a token created in another context
    (e.g. during generator cleanup after a yield); that is ignored deliberately.
    """
    try:
        _graph_context.reset(token)
    except ValueError:
        # Token originates from another context - nothing to restore here.
        return
# The compiled graph is built once and reused; the lock serializes access to the
# cache so concurrent callers cannot build it twice.
_compiled_graph_lock = threading.Lock()
_compiled_graph_cache: Optional[CompiledStateGraph] = None
def _get_or_build_cached_graph() -> CompiledStateGraph:
    """Return the shared compiled graph, building it on first use.

    The whole check-and-build sequence runs under the cache lock, so the graph
    is built at most once even with concurrent callers.
    """
    global _compiled_graph_cache
    with _compiled_graph_lock:
        if _compiled_graph_cache is None:
            _compiled_graph_cache = _build_cached_graph()
            logger.info("Chatbot: compiled graph cached for reuse")
        return _compiled_graph_cache
def _build_cached_graph() -> CompiledStateGraph:
    """Build the chatbot graph with context-resolved nodes and CheckpointerResolver.

    The graph holds no per-request state: every node obtains models, tools and
    prompt sections from the ChatbotGraphContext bound to the current context,
    and persistence goes through CheckpointerResolver. The compiled graph can
    therefore be cached and shared between requests.

    Flow: START -> planner -> (agent_sql_plan -> parse_execute_sql ->
    agent_formulate | agent_tavily <-> tools | agent_answer) -> END.

    Returns:
        The compiled state graph.
    """
    checkpointer = CheckpointerResolver()

    PLANNER_SYSTEM = (
        "Du bist ein Assistent. Antworte NUR mit einem Wort: SQL, TAVILY, BOTH oder NONE.\n"
        "SQL = Fragen zu Lager, Bestand, Artikel, Preisen, wie viele, Anzahl (Datenbankabfrage).\n"
        "TAVILY = Internetsuche, Produktinfos außerhalb der DB, Markttrends.\n"
        "BOTH = beides nötig. NONE = nur Begrüßung oder Danksagung, keine Daten nötig.\n"
        "Beispiele: 'wie viele X auf Lager' -> SQL, 'Infos zu Produkt Y' -> TAVILY."
    )
    SCHEMA_TRUNCATION_SUFFIX = (
        "\n\n[... Schema gekürzt. Wichtige Tabellen: Artikel, Lagerplatz_Artikel, Einkaufspreis, Lagerplatz. "
        "Artikel-Spalte: a.\"Artikelbezeichnung\". "
        "JOIN: Artikel a, Lagerplatz_Artikel l ON a.I_ID = l.R_ARTIKEL, Lagerplatz lp ON l.R_LAGERPLATZ = lp.I_ID.]"
    )
    SQL_PLAN_SUFFIX = (
        "\n\n--- AUSGABEFORMAT (PFLICHT) ---\n"
        "Antworte NUR mit einer SQL SELECT-Abfrage in diesem Format:\n"
        "```sql\nDEINE_SQL_QUERY\n```\n"
        "KRITISCH bei 'wie viele X auf Lager': Liefere ARTIKELZEILEN (Artikelnummer, Artikelbezeichnung, Bestand) "
        "mit LIMIT 20, NICHT nur SELECT COUNT(*). Ohne Detailzeilen kann keine Tabelle angezeigt werden. "
        "Gesamtanzahl optional via Unterabfrage im SELECT."
    )
    FORMULATE_TASK = (
        "\n\n--- AKTUELLE AUFGABE ---\n"
        "Du erhältst eine Benutzerfrage und die exakten Datenbankergebnisse. "
        "KRITISCH: Nutze NUR die gelieferten Daten. Erfinde NIEMALS Daten (keine LED-A01, LED Rot, etc.). "
        "Wenn die Ergebnisse NUR eine Zahl enthalten (z.B. '1. COUNT(*): 806'): Reportiere NUR diese Zahl, KEINE erfundene Tabelle. "
        "Eine Tabelle darf NUR erstellt werden, wenn echte Zeilen '1. Spalte: Wert, ...' in den Daten stehen. "
        "Beachte die obige ANTWORTSTRUKTUR."
    )
    # Phrase that occurs in the injected retry instruction (see tools_with_retry);
    # used to recognise earlier retry instructions in the message history.
    RETRY_MARKER = "alternative Suchstrategie"
    MAX_RETRIES = 2
    bytes_per_token = 3
    reserved_tokens = 3000
    _SQL_KEYWORDS = (
        "lager", "bestand", "artikel", "wie viele", "anzahl", "preis",
        "lieferant", "lieferanten", "bestellen", "verfügbar", "inventar"
    )

    def _get_context_length(ctx: ChatbotGraphContext) -> int:
        """Return the selected model's contextLength, defaulting to 128000."""
        if hasattr(ctx.model, "_selected_model") and ctx.model._selected_model:
            return getattr(ctx.model._selected_model, "contextLength", 128000)
        return 128000

    def _truncate_system_prompt(full_prompt: str, max_chars: int, suffix: str = "") -> str:
        """Cut ``full_prompt`` to at most ``max_chars`` characters, ending with ``suffix``."""
        if len(full_prompt) <= max_chars:
            return full_prompt
        # Clamp at 0: a negative bound would slice relative to the end of the prompt.
        keep = max(0, max_chars - len(suffix))
        return full_prompt[:keep] + suffix

    async def planner_node(state: ChatState) -> dict:
        """Ask the planner model for a routing keyword and store it in ``plan``.

        Defaults to "SQL" when the model names no keyword or no suitable model
        is available. Adds no message to the chat.
        """
        ctx = _get_graph_context()
        human_msgs = [m for m in state.messages if isinstance(m, HumanMessage)]
        last_human = human_msgs[-1].content if human_msgs else ""
        window = [SystemMessage(content=PLANNER_SYSTEM), HumanMessage(content=last_human)]
        plan = "SQL"
        try:
            response = await ctx.planner_model.ainvoke(window)
        except ValueError as exc:
            if "No suitable model found" in str(exc):
                logger.warning(f"Planner model selection failed: {exc}")
                return {"plan": plan}
            raise
        content = (response.content or "").strip().upper()
        for keyword in ("SQL", "TAVILY", "BOTH", "NONE"):
            if keyword in content:
                plan = keyword
                break
        return {"plan": plan}

    def route_by_plan(state: ChatState) -> str:
        """Map the plan to the next node, honouring which tools are configured."""
        ctx = _get_graph_context()
        plan = (state.plan or "SQL").upper()
        if plan == "NONE" and ctx.sql_tool:
            # Safety net: inventory-style wording overrides a NONE verdict.
            last_user = ""
            for m in reversed(state.messages):
                if isinstance(m, HumanMessage):
                    last_user = (m.content or "").lower()
                    break
            if any(kw in last_user for kw in _SQL_KEYWORDS):
                logger.info("Planner returned NONE but user asked inventory question - routing to SQL")
                plan = "SQL"
        if plan in ("SQL", "BOTH") and ctx.sql_tool:
            return "agent_sql_plan"
        if plan == "TAVILY" and ctx.tavily_tool:
            return "agent_tavily"
        return "agent_answer"

    def select_window(ctx: ChatbotGraphContext, msgs: List[BaseMessage], max_tokens_override: Optional[int] = None) -> List[BaseMessage]:
        """Trim the history to the most recent messages that fit the budget.

        The counter sums content lengths (characters) as an approximation; the
        budget is 80% of the override or of the model's context length.
        """
        def approx_counter(items: List[BaseMessage]) -> int:
            return sum(len(getattr(m, "content", "") or "") for m in items)

        max_tokens = max_tokens_override or _get_context_length(ctx)
        return trim_messages(
            msgs,
            strategy="last",
            token_counter=approx_counter,
            max_tokens=int(max_tokens * 0.8),
            start_on="human",
            end_on=("human", "tool"),
            include_system=True,
        )

    async def _agent_common(state: ChatState, system_content: str, llm: Any, node_name: str) -> dict:
        """Invoke ``llm`` on the trimmed history prefixed with ``system_content``.

        If no suitable model is available, a fixed apology message is returned
        instead of raising.
        """
        ctx = _get_graph_context()
        msgs = select_window(ctx, state.messages)
        if not msgs or not isinstance(msgs[0], SystemMessage):
            window = [SystemMessage(content=system_content)] + msgs
        else:
            # Replace any system message from the history with the node's own.
            window = [SystemMessage(content=system_content)] + [m for m in msgs if not isinstance(m, SystemMessage)]
        try:
            response = await llm.ainvoke(window)
        except ValueError as exc:
            if "No suitable model found" in str(exc):
                logger.warning(f"{node_name} model selection failed: {exc}")
                response = AIMessage(
                    content="Es tut mir leid, derzeit steht kein passendes KI-Modell für diese Anfrage zur Verfügung. "
                    "Bitte versuchen Sie es später erneut oder wenden Sie sich an den Administrator."
                )
            else:
                raise
        return {"messages": [response]}

    def _parse_sql_from_content(content: str) -> Optional[str]:
        """Extract a SELECT statement from a fenced code block or a single line."""
        if not content:
            return None
        match = re.search(r"```(?:sql)?\s*([\s\S]*?)```", content)
        if match:
            sql = match.group(1).strip()
            if sql and sql.upper().strip().startswith("SELECT"):
                return sql
        for line in content.split("\n"):
            line = line.strip()
            if line.upper().startswith("SELECT"):
                return line
        return None

    def _sanitize_sql_typos(sql: str) -> str:
        """Apply a fixed set of textual corrections to generated SQL."""
        if not sql:
            return sql
        sql = re.sub(r"WHEN([A-Za-z_][A-Za-z0-9_.\"]*)", r"WHEN \1", sql, flags=re.IGNORECASE)
        sql = re.sub(r"\bLAGerplatz_Artikel\b", "Lagerplatz_Artikel", sql)
        sql = re.sub(r"\bLAGerplatz\b", "Lagerplatz", sql)
        sql = sql.replace('"Einkaufspreis_neu"', '"Einkaufspreis"')
        sql = sql.replace("Einkaufspreis_neu.", "Einkaufspreis.")
        sql = re.sub(r'"Einkaufspreis"\."ARTIKEL"', '"Einkaufspreis"."m_Artikel"', sql)
        return sql

    async def agent_sql_plan_node(state: ChatState) -> dict:
        """Ask the main model for a SQL query using intro + schema prompt sections."""
        ctx = _get_graph_context()
        ctx_len = _get_context_length(ctx)
        max_system_chars = max(1000, int(ctx_len * 0.8 - reserved_tokens) * bytes_per_token) - len(SQL_PLAN_SUFFIX)
        schema_part = ctx.prompt_sections.get("schema") or ctx.prompt_sections.get("intro", "")
        intro_part = (ctx.prompt_sections.get("intro", "") or "")[:400]
        combined = f"{intro_part}\n\n{schema_part}" if intro_part else schema_part
        system_content = _truncate_system_prompt(combined, max_system_chars, SCHEMA_TRUNCATION_SUFFIX) + SQL_PLAN_SUFFIX
        llm = ctx.model
        return await _agent_common(state, system_content, llm, "agent_sql_plan")

    async def parse_execute_sql_node(state: ChatState) -> dict:
        """Parse SQL from the last AI message, run it and emit a ToolMessage.

        Always returns exactly one ToolMessage named "sqlite_query"; failures
        are reported in its content instead of being raised.
        """
        ctx = _get_graph_context()
        sql_t = ctx.sql_tool
        last_msg = state.messages[-1] if state.messages else None
        if not isinstance(last_msg, AIMessage):
            return {"messages": [ToolMessage(content="Fehler: Keine AI-Antwort zum Parsen.", tool_call_id="parse_0", name="sqlite_query")]}
        sql = _parse_sql_from_content(last_msg.content or "")
        if not sql or not sql_t:
            return {"messages": [ToolMessage(content="Konnte keine SQL-Abfrage aus der Antwort extrahieren.", tool_call_id="parse_0", name="sqlite_query")]}
        sql = _sanitize_sql_typos(sql)
        try:
            result = await sql_t.ainvoke({"query": sql})
        except Exception as e:
            logger.error(f"SQL execution failed: {e}")
            result = f"Fehler bei der Ausführung: {e}"
        return {"messages": [ToolMessage(content=str(result), tool_call_id="parse_0", name="sqlite_query")]}

    async def agent_formulate_node(state: ChatState) -> dict:
        """Turn the latest SQL result into the final answer.

        Uses the last human message and the last "sqlite_query" tool message.
        Missing or failed query output yields a fixed message without calling
        the model.
        """
        ctx = _get_graph_context()
        human_content = ""
        tool_content = ""
        for m in state.messages:
            if isinstance(m, HumanMessage):
                human_content = m.content or ""
            if isinstance(m, ToolMessage) and getattr(m, "name", "") == "sqlite_query":
                tool_content = m.content or ""
        if not tool_content or not tool_content.strip():
            logger.warning("agent_formulate: no tool_content (sqlite_query) in state.messages")
            return {"messages": [AIMessage(content="Die Datenbankabfrage konnte keine Ergebnisse liefern. Bitte versuchen Sie es mit einer anderen Formulierung.")]}
        if "Query failed" in tool_content or tool_content.strip().startswith("Error"):
            err_summary = "Die Datenbankabfrage ist fehlgeschlagen."
            if "no such column" in tool_content:
                err_summary += " Ein Spaltenname scheint nicht zu passen. Bitte die Anfrage anders formulieren."
            return {"messages": [AIMessage(content=err_summary)]}
        formatted_data = _tool_output_to_markdown_table(tool_content)
        logger.debug(f"agent_formulate: tool_content length={len(tool_content)}, formatted={len(formatted_data)}")
        ctx_len = _get_context_length(ctx)
        max_system_chars = max(3000, int(ctx_len * 0.5) * bytes_per_token) - len(FORMULATE_TASK)
        resp_struct = ctx.prompt_sections.get("response_structure") or ctx.prompt_sections.get("intro", "")
        intro_formulate = ctx.prompt_sections.get("intro", "")
        combined = f"{intro_formulate}\n\n{resp_struct}" if intro_formulate != resp_struct else resp_struct
        if len(combined) + len(FORMULATE_TASK) > max_system_chars:
            combined = _truncate_system_prompt(combined, max_system_chars - len(FORMULATE_TASK), "")
        system_content = combined + FORMULATE_TASK
        prompt = (
            f"Benutzerfrage: {human_content}\n\n"
            "--- VORGEGEBENE DATEN (diese Tabelle/Zahlen UNVERÄNDERT in die Antwort übernehmen): ---\n"
            f"{formatted_data}\n\n"
            "Die obige Tabelle bzw. Zahlen sind die EINZIGEN erlaubten Daten. Kopiere sie 1:1. "
            "Berechne keine eigenen Summen/Anzahlen - nutze die gelieferten Werte. Formuliere die Antwort:"
        )
        window = [SystemMessage(content=system_content), HumanMessage(content=prompt)]
        try:
            response = await ctx.model.ainvoke(window)
        except ValueError as exc:
            if "No suitable model found" in str(exc):
                response = AIMessage(content="Es gab einen Fehler bei der Formulierung. Bitte versuchen Sie es erneut.")
            else:
                raise
        if response.content:
            response = AIMessage(content=_sanitize_llm_response(response.content))
        return {"messages": [response]}

    async def agent_tavily_node(state: ChatState) -> dict:
        """Run the main model with the tavily/streaming tools bound (if any)."""
        ctx = _get_graph_context()
        resp_struct = ctx.prompt_sections.get("response_structure") or ""
        intro_tavily = ctx.prompt_sections.get("intro", "")
        combined = f"{intro_tavily}\n\n{resp_struct}" if resp_struct else intro_tavily
        system_content = _truncate_system_prompt(combined, 6000, "")
        tools_tavily = [t for t in [ctx.tavily_tool, ctx.streaming_tool] if t is not None]
        llm_tavily = ctx.model.bind_tools(tools=tools_tavily) if tools_tavily else ctx.model
        return await _agent_common(state, system_content, llm_tavily, "agent_tavily")

    async def agent_answer_node(state: ChatState) -> dict:
        """Answer without tools, using the planner model when one is set."""
        ctx = _get_graph_context()
        resp_struct = ctx.prompt_sections.get("response_structure") or ""
        intro_answer = ctx.prompt_sections.get("intro", "")
        combined = f"{intro_answer}\n\n{resp_struct}" if resp_struct else intro_answer
        system_content = _truncate_system_prompt(combined, 6000, "")
        llm = ctx.planner_model if ctx.planner_model else ctx.model
        return await _agent_common(state, system_content, llm, "agent_answer")

    def should_continue_tavily(state: ChatState) -> str:
        """Go to "tools" when the last message requests tool calls, else finish."""
        last = state.messages[-1]
        return "tools" if getattr(last, "tool_calls", None) else END

    def route_back(state: ChatState) -> str:
        """After tool execution, return to the tavily agent when available."""
        ctx = _get_graph_context()
        return "agent_tavily" if ctx.tavily_tool else "agent_answer"

    def _count_retries_in_turn(messages: List[BaseMessage]) -> int:
        """Count retry instructions injected since the last non-retry human message."""
        count = 0
        for m in reversed(messages):
            if not isinstance(m, HumanMessage):
                continue
            if RETRY_MARKER in str(getattr(m, "content", "") or ""):
                count += 1
            else:
                # Reached the human message that started this turn.
                break
        return count

    async def tools_with_retry(state: ChatState) -> dict:
        """Execute the requested tool calls concurrently and return ToolMessages.

        When a tool output signals "no results", a retry instruction is appended
        as a HumanMessage - at most MAX_RETRIES times per turn.
        """
        import asyncio
        import json

        ctx = _get_graph_context()
        last_message = state.messages[-1]
        tool_calls = getattr(last_message, "tool_calls", [])
        if not tool_calls:
            return {"messages": []}
        tools_by_name = ctx.tools_by_name

        async def execute_single_tool(tool_call):
            """Run one tool call; errors are returned as ToolMessage content."""
            tool_name = tool_call.get("name") or tool_call.get("function", {}).get("name")
            tool_id = tool_call.get("id", f"call_{tool_name}")
            args = tool_call.get("args") or tool_call.get("function", {}).get("arguments", {})
            if isinstance(args, str):
                try:
                    args = json.loads(args)
                except Exception:
                    args = {"input": args}
            tool = tools_by_name.get(tool_name)
            if not tool:
                return ToolMessage(content=f"Error: Tool '{tool_name}' not found", tool_call_id=tool_id, name=tool_name)
            try:
                if hasattr(tool, "coroutine") and asyncio.iscoroutinefunction(tool.coroutine):
                    result = await tool.coroutine(**args)
                elif hasattr(tool, "ainvoke"):
                    result = await tool.ainvoke(args)
                else:
                    result = tool.invoke(args)
                return ToolMessage(content=str(result), tool_call_id=tool_id, name=tool_name)
            except Exception as e:
                logger.error(f"Tool {tool_name} failed: {e}")
                return ToolMessage(content=f"Error executing {tool_name}: {str(e)}", tool_call_id=tool_id, name=tool_name)

        tool_messages = await asyncio.gather(
            *[execute_single_tool(tc) for tc in tool_calls],
            return_exceptions=True
        )
        result_messages = []
        for i, msg in enumerate(tool_messages):
            if isinstance(msg, Exception):
                tool_call = tool_calls[i]
                tool_name = tool_call.get("name", "unknown")
                tool_id = tool_call.get("id", f"call_{i}")
                result_messages.append(ToolMessage(content=f"Error: {str(msg)}", tool_call_id=tool_id, name=tool_name))
            else:
                result_messages.append(msg)
        result = {"messages": result_messages}

        no_results_keywords = [
            "returned 0 rows", "no data", "keine artikel gefunden", "keine ergebnisse"
        ]
        found_no_results = False
        for msg in result_messages:
            content = getattr(msg, "content", "")
            if isinstance(content, str):
                content_lower = content.lower()
                if any(keyword in content_lower for keyword in no_results_keywords):
                    found_no_results = True
                    break
        # The retry instruction never contains the word "retry", so earlier
        # retries are recognised by RETRY_MARKER to make the limit effective.
        if found_no_results and _count_retries_in_turn(state.messages) < MAX_RETRIES:
            logger.info("No results found in tool output, adding retry instruction")
            retry_message = HumanMessage(
                content="WICHTIG: Die vorherige Suche hat keine Ergebnisse gefunden. "
                "Bitte versuche eine alternative Suchstrategie:\n"
                "1. Wenn die Frage im Format 'X von Y' war (z.B. 'Lampen von Eaton'), "
                "verwende IMMER eine Kombination aus Lieferanten-Filter (WHERE a.\"Lieferant\" LIKE '%Y%') "
                "UND Produkttyp-Filter (WHERE a.\"Artikelbezeichnung\" LIKE '%X%' OR ...)\n"
                "2. Verwende mehrere Synonyme für den Produkttyp (z.B. bei 'Lampen': Lampe, LED, Beleuchtung, Licht, Leuchte, Strahler)\n"
                "3. Führe zuerst eine COUNT-Abfrage durch, dann die Detail-Abfrage mit Lagerbeständen\n"
                "4. Verwende LIKE '%Lieferant%' für den Lieferanten-Filter, um auch Varianten zu finden"
            )
            result["messages"].append(retry_message)
        return result

    workflow = StateGraph(ChatState)
    workflow.add_node("planner", planner_node)
    workflow.add_node("agent_sql_plan", agent_sql_plan_node)
    workflow.add_node("parse_execute_sql", parse_execute_sql_node)
    workflow.add_node("agent_formulate", agent_formulate_node)
    workflow.add_node("tools", tools_with_retry)
    workflow.add_node("agent_tavily", agent_tavily_node)
    workflow.add_node("agent_answer", agent_answer_node)
    workflow.add_edge(START, "planner")
    workflow.add_conditional_edges("planner", route_by_plan)
    workflow.add_edge("agent_sql_plan", "parse_execute_sql")
    workflow.add_edge("parse_execute_sql", "agent_formulate")
    workflow.add_edge("agent_formulate", END)
    workflow.add_conditional_edges("agent_tavily", should_continue_tavily)
    workflow.add_edge("agent_answer", END)
    workflow.add_conditional_edges("tools", route_back)
    return workflow.compile(checkpointer=checkpointer)
@dataclass @dataclass
class Chatbot: class Chatbot:
"""Represents a chatbot.""" """Represents a chatbot."""
model: AICenterChatModel model: AICenterChatModel
memory: DatabaseCheckpointer memory: DatabaseCheckpointer
planner_model: Optional[AICenterChatModel] = None # Fast model for routing (SQL/TAVILY/NONE)
app: CompiledStateGraph = None app: CompiledStateGraph = None
_tools: List[Any] = field(default_factory=list) # Configured tools (for cached graph context)
system_prompt: str = "You are a helpful assistant." system_prompt: str = "You are a helpful assistant."
workflow_id: str = "default" workflow_id: str = "default"
config: Optional["ChatbotConfig"] = None config: Optional["ChatbotConfig"] = None
@ -189,6 +595,7 @@ class Chatbot:
workflow_id: str = "default", workflow_id: str = "default",
config: Optional["ChatbotConfig"] = None, config: Optional["ChatbotConfig"] = None,
event_manager=None, event_manager=None,
planner_model: Optional[AICenterChatModel] = None,
) -> "Chatbot": ) -> "Chatbot":
"""Factory method to create and configure a Chatbot instance. """Factory method to create and configure a Chatbot instance.
@ -199,6 +606,7 @@ class Chatbot:
workflow_id: The workflow ID (maps to thread_id). workflow_id: The workflow ID (maps to thread_id).
config: Optional chatbot configuration for dynamic tool enablement. config: Optional chatbot configuration for dynamic tool enablement.
event_manager: Optional event manager for streaming (passed from route). event_manager: Optional event manager for streaming (passed from route).
planner_model: Optional fast model for planner/routing (default: same as model).
Returns: Returns:
A configured Chatbot instance. A configured Chatbot instance.
@ -210,9 +618,11 @@ class Chatbot:
workflow_id=workflow_id, workflow_id=workflow_id,
config=config, config=config,
_event_manager=event_manager, _event_manager=event_manager,
planner_model=planner_model,
) )
configured_tools = await instance._configure_tools() configured_tools = await instance._configure_tools()
instance.app = instance._build_app(memory, configured_tools) instance._tools = configured_tools
instance.app = _get_or_build_cached_graph()
return instance return instance
async def _configure_tools(self) -> List[Any]: async def _configure_tools(self) -> List[Any]:
@ -281,6 +691,7 @@ class Chatbot:
tools_sql = [t for t in [sql_tool, tavily_tool, streaming_tool] if t is not None] tools_sql = [t for t in [sql_tool, tavily_tool, streaming_tool] if t is not None]
tools_tavily = [t for t in [tavily_tool, streaming_tool] if t is not None] tools_tavily = [t for t in [tavily_tool, streaming_tool] if t is not None]
llm_plain = self.model llm_plain = self.model
llm_planner = self.planner_model if self.planner_model else self.model
# SQL path uses structured prompts + parse/execute (no native tool calling) - fits /api/analyze # SQL path uses structured prompts + parse/execute (no native tool calling) - fits /api/analyze
llm_tavily = self.model.bind_tools(tools=tools_tavily) if tools_tavily else self.model llm_tavily = self.model.bind_tools(tools=tools_tavily) if tools_tavily else self.model
@ -348,7 +759,8 @@ class Chatbot:
async def planner_node(state: ChatState) -> dict: async def planner_node(state: ChatState) -> dict:
"""Planner: minimal prompt, no tools. Outputs SQL/TAVILY/BOTH/NONE. """Planner: minimal prompt, no tools. Outputs SQL/TAVILY/BOTH/NONE.
Does NOT add planner message to chat - only sets state.plan for routing.""" Does NOT add planner message to chat - only sets state.plan for routing.
Uses llm_planner (fast model) when available for lower latency."""
human_msgs = [m for m in state.messages if isinstance(m, HumanMessage)] human_msgs = [m for m in state.messages if isinstance(m, HumanMessage)]
last_human = human_msgs[-1].content if human_msgs else "" last_human = human_msgs[-1].content if human_msgs else ""
window = [ window = [
@ -357,7 +769,7 @@ class Chatbot:
] ]
plan = "SQL" plan = "SQL"
try: try:
response = await llm_plain.ainvoke(window) response = await llm_planner.ainvoke(window)
except ValueError as exc: except ValueError as exc:
if "No suitable model found" in str(exc): if "No suitable model found" in str(exc):
logger.warning(f"Planner model selection failed: {exc}") logger.warning(f"Planner model selection failed: {exc}")
@ -555,12 +967,12 @@ class Chatbot:
return await _agent_common(state, system_content, llm_tavily, "agent_tavily") return await _agent_common(state, system_content, llm_tavily, "agent_tavily")
async def agent_answer_node(state: ChatState) -> dict: async def agent_answer_node(state: ChatState) -> dict:
"""Agent with no tools. Uses intro + response_structure.""" """Agent with no tools (plan NONE). Uses fast model for lower latency."""
resp_struct = _prompt_sections["response_structure"] or "" resp_struct = _prompt_sections["response_structure"] or ""
intro_answer = _prompt_sections["intro"] intro_answer = _prompt_sections["intro"]
combined = f"{intro_answer}\n\n{resp_struct}" if resp_struct else intro_answer combined = f"{intro_answer}\n\n{resp_struct}" if resp_struct else intro_answer
system_content = _truncate_system_prompt(combined, 6000, "") system_content = _truncate_system_prompt(combined, 6000, "")
return await _agent_common(state, system_content, llm_plain, "agent_answer") return await _agent_common(state, system_content, llm_planner, "agent_answer")
def should_continue_tavily(state: ChatState) -> str: def should_continue_tavily(state: ChatState) -> str:
last = state.messages[-1] last = state.messages[-1]
@ -722,16 +1134,29 @@ class Chatbot:
Returns: Returns:
The list of messages in the chat history. The list of messages in the chat history.
""" """
# Set the right thread ID for memory
config = {"configurable": {"thread_id": chat_id}} config = {"configurable": {"thread_id": chat_id}}
tools_by_name = {t.name: t for t in self._tools}
# Single-turn chat (non-streaming) graph_ctx = ChatbotGraphContext(
result = await self.app.ainvoke( model=self.model,
{"messages": [HumanMessage(content=message)]}, config=config planner_model=self.planner_model or self.model,
tools=self._tools,
tools_by_name=tools_by_name,
sql_tool=tools_by_name.get("sqlite_query"),
tavily_tool=tools_by_name.get("tavily_search"),
streaming_tool=tools_by_name.get("send_streaming_message"),
prompt_sections=_split_system_prompt(self.system_prompt),
system_prompt=self.system_prompt,
) )
ctx_token = _set_graph_context(graph_ctx)
# Extract and return the messages from the result cp_token = set_checkpointer(self.memory)
return result["messages"] try:
result = await self.app.ainvoke(
{"messages": [HumanMessage(content=message)]}, config=config
)
return result["messages"]
finally:
_reset_graph_context(ctx_token)
reset_checkpointer(cp_token)
async def stream_events( async def stream_events(
self, *, message: str, chat_id: str = "default" self, *, message: str, chat_id: str = "default"
@ -757,6 +1182,21 @@ class Chatbot:
"""Return True if the event is from the root run (v2: empty parent_ids).""" """Return True if the event is from the root run (v2: empty parent_ids)."""
return not ev.get("parent_ids") return not ev.get("parent_ids")
# Build tool lookup for cached graph context
tools_by_name = {t.name: t for t in self._tools}
graph_ctx = ChatbotGraphContext(
model=self.model,
planner_model=self.planner_model or self.model,
tools=self._tools,
tools_by_name=tools_by_name,
sql_tool=tools_by_name.get("sqlite_query"),
tavily_tool=tools_by_name.get("tavily_search"),
streaming_tool=tools_by_name.get("send_streaming_message"),
prompt_sections=_split_system_prompt(self.system_prompt),
system_prompt=self.system_prompt,
)
ctx_token = _set_graph_context(graph_ctx)
cp_token = set_checkpointer(self.memory)
try: try:
async for event in self.app.astream_events( async for event in self.app.astream_events(
{"messages": [HumanMessage(content=message)]}, {"messages": [HumanMessage(content=message)]},
@ -767,6 +1207,25 @@ class Chatbot:
ename = event.get("name") or "" ename = event.get("name") or ""
edata = event.get("data") or {} edata = event.get("data") or {}
# Stream LLM tokens for ChatGPT-like incremental display
if etype in ("on_llm_stream", "on_chat_model_stream"):
ch = edata.get("chunk")
if ch is None:
continue
# Chunk can be string, AIMessageChunk (has .content), or dict
content = ""
if isinstance(ch, str):
content = ch
elif hasattr(ch, "content"):
content = ch.content or ""
if isinstance(content, list):
content = "".join(str(x) for x in content)
elif isinstance(ch, dict):
content = ch.get("content", "") or ""
if isinstance(content, str) and content:
yield {"type": "chunk", "content": content}
continue
# Stream human-readable progress via the special send_streaming_message tool # Stream human-readable progress via the special send_streaming_message tool
# Match the legacy implementation exactly (line 267-272 in legacy/chatbot.py) # Match the legacy implementation exactly (line 267-272 in legacy/chatbot.py)
if etype == "on_tool_start": if etype == "on_tool_start":
@ -833,3 +1292,6 @@ class Chatbot:
# Emit a single error envelope and end the stream # Emit a single error envelope and end the stream
logger.error(f"Exception in stream_events: {exc}", exc_info=True) logger.error(f"Exception in stream_events: {exc}", exc_info=True)
yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"} yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"}
finally:
_reset_graph_context(ctx_token)
reset_checkpointer(cp_token)

View file

@ -397,8 +397,8 @@ class ChatObjects:
dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET") dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
dbPort = int(APP_CONFIG.get("DB_PORT", 5432)) dbPort = int(APP_CONFIG.get("DB_PORT", 5432))
# Create database connector directly from modules.connectors.connectorDbPostgre import _get_cached_connector
self.db = DatabaseConnector( self.db = _get_cached_connector(
dbHost=dbHost, dbHost=dbHost,
dbDatabase=dbDatabase, dbDatabase=dbDatabase,
dbUser=dbUser, dbUser=dbUser,
@ -769,6 +769,72 @@ class ChatObjects:
"""Backward-compat alias: workflowId maps to conversationId.""" """Backward-compat alias: workflowId maps to conversationId."""
return self.getConversation(workflowId) return self.getConversation(workflowId)
def getWorkflowMinimal(self, workflowId: str) -> Optional[ChatbotConversation]:
    """Fetch only the conversation record for ``workflowId`` (resume path).

    Logs and messages are not loaded; the returned object carries empty lists
    for both. Returns None when the RBAC-filtered lookup yields no record.
    """
    rows = getRecordsetWithRBAC(
        self.db,
        ChatbotConversation,
        self.currentUser,
        recordFilter={"id": workflowId},
        mandateId=self.mandateId,
        featureInstanceId=self.featureInstanceId,
        featureCode="chatbot",
    )
    if not rows:
        return None

    record = rows[0]
    stepLimit = record.get("maxSteps")
    if stepLimit is None:
        # No stored limit: use the default of 10 steps.
        stepLimit = 10

    return ChatbotConversation(
        id=record["id"],
        featureInstanceId=record.get("featureInstanceId") or self.featureInstanceId or "",
        name=record.get("name"),
        status=record.get("status", "running"),
        currentRound=record.get("currentRound", 0) or 0,
        lastActivity=record.get("lastActivity", getUtcTimestamp()),
        startedAt=record.get("startedAt", getUtcTimestamp()),
        workflowMode=ChatbotWorkflowModeEnum(record.get("workflowMode", "Chatbot")),
        maxSteps=stepLimit,
        logs=[],
        messages=[],
    )
def getMessageCount(self, conversationId: str) -> int:
    """Return how many messages the conversation has.

    Counts the ChatbotMessage records returned by one RBAC-filtered lookup;
    an empty or missing result counts as 0.
    """
    rows = getRecordsetWithRBAC(
        self.db,
        ChatbotMessage,
        self.currentUser,
        recordFilter={"conversationId": conversationId},
        mandateId=self.mandateId,
        featureInstanceId=self.featureInstanceId,
        featureCode="chatbot",
    )
    if not rows:
        return 0
    return len(rows)
def updateWorkflowMinimal(
    self, workflowId: str, workflowData: Dict[str, Any]
) -> ChatbotConversation:
    """Update the simple fields of a conversation without loading logs/messages.

    Intended for the resume path where the caller holds a minimal workflow.
    ``lastActivity`` is always set to the current timestamp. The returned
    object carries empty ``logs`` and ``messages`` lists.

    Raises:
        PermissionError: If the user may not update this conversation.
    """
    allowed = self.checkRbacPermission(ChatbotConversation, "update", workflowId)
    if not allowed:
        raise PermissionError(f"No permission to update conversation {workflowId}")

    # Only the simple (non-object) fields are written; the rest is discarded.
    fieldsToWrite, _ = self._separateObjectFields(ChatbotConversation, workflowData)
    fieldsToWrite["lastActivity"] = getUtcTimestamp()
    record = self.db.recordModify(ChatbotConversation, workflowId, fieldsToWrite)

    stepLimit = record.get("maxSteps")
    if stepLimit is None:
        # No stored limit: use the default of 10 steps.
        stepLimit = 10

    return ChatbotConversation(
        id=record["id"],
        featureInstanceId=record.get("featureInstanceId") or self.featureInstanceId or "",
        name=record.get("name"),
        status=record.get("status", "running"),
        currentRound=record.get("currentRound", 0) or 0,
        lastActivity=record.get("lastActivity", getUtcTimestamp()),
        startedAt=record.get("startedAt", getUtcTimestamp()),
        workflowMode=ChatbotWorkflowModeEnum(record.get("workflowMode", "Chatbot")),
        maxSteps=stepLimit,
        logs=[],
        messages=[],
    )
def createConversation(self, conversationData: Dict[str, Any]) -> ChatbotConversation: def createConversation(self, conversationData: Dict[str, Any]) -> ChatbotConversation:
"""Creates a new conversation if user has permission.""" """Creates a new conversation if user has permission."""
if not self.checkRbacPermission(ChatbotConversation, "create"): if not self.checkRbacPermission(ChatbotConversation, "create"):
@ -825,9 +891,7 @@ class ChatObjects:
updated = self.db.recordModify(ChatbotConversation, conversationId, simpleFields) updated = self.db.recordModify(ChatbotConversation, conversationId, simpleFields)
logs = self.getLogs(conversationId) # Reuse logs/messages from conv — update only touches simple fields, not related data
messages = self.getMessages(conversationId)
return ChatbotConversation( return ChatbotConversation(
id=updated["id"], id=updated["id"],
featureInstanceId=updated.get("featureInstanceId") or conv.featureInstanceId or self.featureInstanceId or "", featureInstanceId=updated.get("featureInstanceId") or conv.featureInstanceId or self.featureInstanceId or "",
@ -838,8 +902,8 @@ class ChatObjects:
startedAt=updated.get("startedAt", conv.startedAt), startedAt=updated.get("startedAt", conv.startedAt),
workflowMode=ChatbotWorkflowModeEnum(updated.get("workflowMode", conv.workflowMode.value)), workflowMode=ChatbotWorkflowModeEnum(updated.get("workflowMode", conv.workflowMode.value)),
maxSteps=updated.get("maxSteps") if updated.get("maxSteps") is not None else conv.maxSteps, maxSteps=updated.get("maxSteps") if updated.get("maxSteps") is not None else conv.maxSteps,
logs=logs, logs=conv.logs,
messages=messages messages=conv.messages
) )
def updateWorkflow(self, workflowId: str, workflowData: Dict[str, Any]) -> ChatbotConversation: def updateWorkflow(self, workflowId: str, workflowData: Dict[str, Any]) -> ChatbotConversation:
@ -955,11 +1019,13 @@ class ChatObjects:
if pagination and pagination.sort: if pagination and pagination.sort:
messageDicts = self._applySorting(messageDicts, pagination.sort) messageDicts = self._applySorting(messageDicts, pagination.sort)
# If no pagination requested, return all items # If no pagination requested, return all items (batch-fetch documents to avoid N+1)
if pagination is None: if pagination is None:
msg_ids = [m["id"] for m in messageDicts]
docs_by_message = self.getDocumentsForMessages(msg_ids) if msg_ids else {}
chat_messages = [] chat_messages = []
for msg in messageDicts: for msg in messageDicts:
documents = self.getDocuments(msg["id"]) documents = docs_by_message.get(msg["id"], [])
chat_message = ChatbotMessage( chat_message = ChatbotMessage(
id=msg["id"], id=msg["id"],
conversationId=msg["conversationId"], conversationId=msg["conversationId"],
@ -994,10 +1060,11 @@ class ChatObjects:
startIdx = (pagination.page - 1) * pagination.pageSize startIdx = (pagination.page - 1) * pagination.pageSize
endIdx = startIdx + pagination.pageSize endIdx = startIdx + pagination.pageSize
pagedMessageDicts = messageDicts[startIdx:endIdx] pagedMessageDicts = messageDicts[startIdx:endIdx]
paged_msg_ids = [m["id"] for m in pagedMessageDicts]
docs_by_message = self.getDocumentsForMessages(paged_msg_ids) if paged_msg_ids else {}
chat_messages = [] chat_messages = []
for msg in pagedMessageDicts: for msg in pagedMessageDicts:
documents = self.getDocuments(msg["id"]) documents = docs_by_message.get(msg["id"], [])
chat_message = ChatbotMessage( chat_message = ChatbotMessage(
id=msg["id"], id=msg["id"],
conversationId=msg["conversationId"], conversationId=msg["conversationId"],
@ -1224,6 +1291,17 @@ class ChatObjects:
logger.error(f"Error updating message {messageId}: {str(e)}", exc_info=True) logger.error(f"Error updating message {messageId}: {str(e)}", exc_info=True)
raise ValueError(f"Error updating message {messageId}: {str(e)}") raise ValueError(f"Error updating message {messageId}: {str(e)}")
def createStat(self, statData: Dict[str, Any]):
    """Create a ChatStat, persisting it on a best-effort basis.

    Exists for compatibility with ChatService; the chatbot schema may not be
    able to store stats.

    Args:
        statData: Keyword data used to construct the ChatStat and passed to
            ``self.db.recordCreate``.

    Returns:
        A ChatStat built from the created record when persistence succeeds,
        otherwise the in-memory ChatStat built directly from ``statData``.
    """
    from modules.datamodels.datamodelChat import ChatStat

    # Built outside the try block: invalid statData raises to the caller
    # instead of being swallowed by the persistence fallback below.
    unsavedStat = ChatStat(**statData)
    try:
        createdRecord = self.db.recordCreate(ChatStat, statData)
        return ChatStat(**createdRecord)
    except Exception as exc:
        # Deliberate best-effort: log at debug level and hand back the
        # unpersisted object.
        logger.debug(f"createStat: not persisting (chatbot schema): {exc}")
    return unsavedStat
def deleteMessage(self, conversationId: str, messageId: str) -> bool: def deleteMessage(self, conversationId: str, messageId: str) -> bool:
"""Deletes a conversation message and related data if user has access.""" """Deletes a conversation message and related data if user has access."""
try: try:
@ -1308,6 +1386,30 @@ class ChatObjects:
logger.error(f"Error getting message documents: {str(e)}") logger.error(f"Error getting message documents: {str(e)}")
return [] return []
def getDocumentsForMessages(self, messageIds: List[str]) -> Dict[str, List[ChatbotDocument]]:
    """Fetch the documents of several messages with one recordset lookup.

    Args:
        messageIds: IDs of the messages whose documents are wanted.

    Returns:
        Mapping ``{messageId: [ChatbotDocument, ...]}`` containing every
        requested ID (with an empty list when it has no documents). An empty
        input gives an empty dict. If the lookup or conversion fails, the
        error is logged and every requested ID maps to an empty list.
    """
    if not messageIds:
        return {}

    try:
        rows = getRecordsetWithRBAC(
            self.db,
            ChatbotDocument,
            self.currentUser,
            recordFilter={"messageId": messageIds},
            mandateId=self.mandateId,
            featureInstanceId=self.featureInstanceId,
            featureCode="chatbot",
        )
        grouped: Dict[str, List[ChatbotDocument]] = {messageId: [] for messageId in messageIds}
        for row in rows:
            bucket = grouped.get(row.get("messageId"))
            if bucket is None:
                # Row belongs to a message that was not requested; ignore it.
                continue
            bucket.append(ChatbotDocument(**row))
        return grouped
    except Exception as exc:
        logger.error(f"Error getting documents for messages: {exc}")
        return {messageId: [] for messageId in messageIds}
def createDocument(self, documentData: Dict[str, Any]) -> ChatbotDocument: def createDocument(self, documentData: Dict[str, Any]) -> ChatbotDocument:
"""Creates a document for a message in normalized table.""" """Creates a document for a message in normalized table."""
try: try:
@ -1451,11 +1553,14 @@ class ChatObjects:
items = [] items = []
messages = getRecordsetWithRBAC(self.db, ChatbotMessage, self.currentUser, recordFilter={"conversationId": conversationId}, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode="chatbot") messages = getRecordsetWithRBAC(self.db, ChatbotMessage, self.currentUser, recordFilter={"conversationId": conversationId}, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode="chatbot")
# Batch-fetch documents for all messages (avoids N+1)
message_ids = [m["id"] for m in messages if afterTimestamp is None or parseTimestamp(m.get("publishedAt"), default=getUtcTimestamp()) > afterTimestamp]
docs_by_message = self.getDocumentsForMessages(message_ids) if message_ids else {}
for msg in messages: for msg in messages:
msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp()) msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp())
if afterTimestamp is not None and msgTimestamp <= afterTimestamp: if afterTimestamp is not None and msgTimestamp <= afterTimestamp:
continue continue
documents = self.getDocuments(msg["id"]) documents = docs_by_message.get(msg["id"], [])
chatMessage = ChatbotMessage( chatMessage = ChatbotMessage(
id=msg["id"], id=msg["id"],
conversationId=msg["conversationId"], conversationId=msg["conversationId"],

View file

@ -6,7 +6,7 @@ Handles feature initialization and RBAC catalog registration.
""" """
import logging import logging
from typing import Dict, List, Any from typing import Dict, List, Any, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,6 +48,14 @@ RESOURCE_OBJECTS = [
}, },
] ]
# Services the chatbot feature declares as required. Each entry names a
# service via "serviceKey" and carries a free-text "usage" note under "meta".
# NOTE(review): the original comment states these are resolved via the service
# center; the code that consumes this list is not visible here - confirm.
REQUIRED_SERVICES = [
    {"serviceKey": "chat", "meta": {"usage": "File info, document handling"}},
    {"serviceKey": "ai", "meta": {"usage": "AI calls, conversation name generation"}},
    {"serviceKey": "billing", "meta": {"usage": "Usage tracking, balance checks"}},
    {"serviceKey": "streaming", "meta": {"usage": "Event manager, ChatStreamingHelper"}},
]
# Template roles for this feature # Template roles for this feature
# Role names MUST follow convention: {featureCode}-{roleName} # Role names MUST follow convention: {featureCode}-{roleName}
TEMPLATE_ROLES = [ TEMPLATE_ROLES = [
@ -170,6 +178,76 @@ def registerFeature(catalogService) -> bool:
return False return False
def getChatbotServices(
    user,
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None,
    workflow=None,
) -> "_ChatbotServiceHub":
    """Build the lightweight service hub used by the chatbot.

    Only the chat, ai and streaming services are instantiated, rather than
    going through the full legacy Services hub (the original note attributes
    a latency saving to skipping that hub's feature loading).

    ``interfaceDbChat`` is the chatbot's own interface
    (interfaceFeatureChatbot), so callers such as chatProcess can reuse
    ``hub.interfaceDbChat`` instead of creating a second one.

    Args:
        user: User the hub and its interfaces are created for.
        mandateId: Optional mandate context forwarded to the interfaces.
        featureInstanceId: Optional feature instance context.
        workflow: Optional workflow stored on the hub.

    Returns:
        A populated _ChatbotServiceHub.
    """
    from modules.services import PublicService
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
    from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
    from modules.services.serviceChat.mainServiceChat import ChatService
    from modules.services.serviceAi.mainServiceAi import AiService
    from modules.services.serviceStreaming.mainServiceStreaming import StreamingService

    serviceHub = _ChatbotServiceHub()

    # Request context.
    serviceHub.user = user
    serviceHub.mandateId = mandateId
    serviceHub.featureInstanceId = featureInstanceId
    serviceHub.workflow = workflow
    serviceHub.featureCode = "chatbot"
    serviceHub.allowedProviders = None

    # Database interfaces. interfaceDbComponent is not created here: the hub
    # loads it lazily on first access, so only its backing field is reset.
    serviceHub.interfaceDbApp = getAppInterface(user, mandateId=mandateId)
    serviceHub._interfaceDbComponent_val = None
    serviceHub.interfaceDbChat = getChatbotInterface(
        user, mandateId=mandateId, featureInstanceId=featureInstanceId
    )

    # Services are created last because each one receives the hub itself.
    serviceHub.chat = PublicService(ChatService(serviceHub))
    serviceHub.ai = PublicService(AiService(serviceHub), functionsOnly=False)
    serviceHub.streaming = PublicService(StreamingService(serviceHub))

    return serviceHub
class _ChatbotServiceHub:
    """Minimal service hub for the chatbot (chat, ai, streaming).

    All attributes default to ``None`` at class level (``featureCode``
    defaults to "chatbot") and are filled in per instance by the code that
    builds the hub.
    """

    # Request context
    user = None
    mandateId = None
    featureInstanceId = None
    workflow = None
    featureCode = "chatbot"
    allowedProviders = None

    # Database interfaces
    interfaceDbApp = None
    interfaceDbChat = None
    # Backing field for the lazily created ``interfaceDbComponent`` property.
    _interfaceDbComponent_val = None

    # Services
    chat = None
    ai = None
    streaming = None

    @property
    def interfaceDbComponent(self):
        """Return the component DB interface, creating it on first access."""
        existing = self._interfaceDbComponent_val
        if existing is not None:
            return existing

        # Imported here so the module is only loaded when actually needed.
        from modules.interfaces.interfaceDbManagement import getInterface as getComponentInterface

        self._interfaceDbComponent_val = getComponentInterface(
            self.user, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId
        )
        return self._interfaceDbComponent_val
def _syncTemplateRolesToDb() -> int: def _syncTemplateRolesToDb() -> int:
""" """
Sync template roles and their AccessRules to the database. Sync template roles and their AccessRules to the database.

View file

@ -33,6 +33,17 @@ from modules.features.chatbot.interfaceFeatureChatbot import ChatbotConversation
from modules.features.chatbot import chatProcess from modules.features.chatbot import chatProcess
from modules.services.serviceStreaming import get_event_manager from modules.services.serviceStreaming import get_event_manager
# Pre-warm AI connectors when this router loads (before first request).
# Ensures connectors are ready; avoids 4-8 s delay on first chatbot message.
try:
import modules.aicore.aicoreModelRegistry # noqa: F401
from modules.aicore.aicoreModelRegistry import modelRegistry
modelRegistry.ensureConnectorsRegistered()
modelRegistry.refreshModels(force=True)
logging.getLogger(__name__).info("Chatbot router: AI connectors pre-warmed")
except Exception as e:
logging.getLogger(__name__).warning(f"Chatbot AI pre-warm failed: {e}")
# Configure logger # Configure logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,12 +54,14 @@ router = APIRouter(
responses={404: {"description": "Not found"}} responses={404: {"description": "Not found"}}
) )
def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None): def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None, mandateId: Optional[str] = None):
"""Get chatbot interface with instance context.""" """Get chatbot interface with instance context.
mandateId = str(context.mandateId) if context.mandateId else None Pass mandateId when available (e.g. from _validateInstanceAccess) to ensure cache hit with getChatbotServices.
"""
effective_mandate = mandateId if mandateId is not None else (str(context.mandateId) if context.mandateId else None)
return interfaceDbChat.getInterface( return interfaceDbChat.getInterface(
context.user, context.user,
mandateId=mandateId, mandateId=effective_mandate,
featureInstanceId=instanceId featureInstanceId=instanceId
) )
@ -125,7 +138,7 @@ def get_chatbot_threads(
mandateId = _validateInstanceAccess(instanceId, context) mandateId = _validateInstanceAccess(instanceId, context)
try: try:
interfaceDbChat = _getServiceChat(context, instanceId) interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)
if workflowId: if workflowId:
workflow = interfaceDbChat.getWorkflow(workflowId) workflow = interfaceDbChat.getWorkflow(workflowId)
@ -266,12 +279,14 @@ async def stream_chatbot_start(
async def event_stream(): async def event_stream():
"""Async generator for SSE events - pure event-driven streaming (no polling).""" """Async generator for SSE events - pure event-driven streaming (no polling)."""
try: try:
# Get interface for initial data and status checks # Yield keepalive immediately so client gets 200 + first byte fast (normal chatbot feel)
interfaceDbChat = _getServiceChat(context, instanceId) yield ": keepalive\n\n"
# Get current workflow to check if resuming and get current round # Use same mandateId as chatProcess so we hit interface cache (avoid duplicate DB init)
current_workflow = interfaceDbChat.getWorkflow(workflow.id) interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)
current_round = current_workflow.currentRound if current_workflow else None
# Use workflow from chatProcess (no refetch)
current_round = workflow.currentRound if workflow else None
is_resuming = final_workflow_id is not None and current_round and current_round > 1 is_resuming = final_workflow_id is not None and current_round and current_round > 1
# Send initial chat data (exact format as chatData endpoint) - only once at start # Send initial chat data (exact format as chatData endpoint) - only once at start
@ -358,7 +373,7 @@ async def stream_chatbot_start(
event_type = event.get("type") event_type = event.get("type")
event_data = event.get("data", {}) event_data = event.get("data", {})
# Emit chatdata events (messages, logs, stats, status) in exact chatData format # Emit chatdata events (messages, logs, stats, status, chunk) in exact chatData format
if event_type == "chatdata" and event_data: if event_type == "chatdata" and event_data:
# Handle status events (transient UI feedback) # Handle status events (transient UI feedback)
if event_data.get("type") == "status": if event_data.get("type") == "status":
@ -368,6 +383,13 @@ async def stream_chatbot_start(
"label": event_data.get("label", "") "label": event_data.get("label", "")
} }
yield f"data: {json.dumps(status_item)}\n\n" yield f"data: {json.dumps(status_item)}\n\n"
elif event_data.get("type") == "chunk":
# Token chunks for ChatGPT-like streaming
chunk_item = {
"type": "chunk",
"content": event_data.get("content", "")
}
yield f"data: {json.dumps(chunk_item)}\n\n"
else: else:
# Emit other chatdata items (messages, logs, stats) in exact chatData format # Emit other chatdata items (messages, logs, stats) in exact chatData format
chatdata_item = event_data chatdata_item = event_data
@ -444,7 +466,7 @@ async def stop_chatbot(
try: try:
# Get chatbot interface with instance context # Get chatbot interface with instance context
interfaceDbChat = _getServiceChat(context, instanceId) interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)
# Get workflow to verify it exists and belongs to this instance # Get workflow to verify it exists and belongs to this instance
workflow = interfaceDbChat.getWorkflow(workflowId) workflow = interfaceDbChat.getWorkflow(workflowId)
@ -519,8 +541,7 @@ def delete_chatbot(
mandateId = _validateInstanceAccess(instanceId, context) mandateId = _validateInstanceAccess(instanceId, context)
try: try:
# Get service center interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)
interfaceDbChat = _getServiceChat(context, instanceId)
# Get workflow directly (interface already handles mandate filtering) # Get workflow directly (interface already handles mandate filtering)
workflow = interfaceDbChat.getWorkflow(workflowId) workflow = interfaceDbChat.getWorkflow(workflowId)

View file

@ -19,7 +19,7 @@ from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, ProcessingModeEnum from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, ProcessingModeEnum
from modules.datamodels.datamodelDocref import DocumentReferenceList, DocumentItemReference from modules.datamodels.datamodelDocref import DocumentReferenceList, DocumentItemReference
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
from modules.services import getInterface as getServices from modules.features.chatbot.mainChatbot import getChatbotServices
from modules.features.chatbot.chatbot import Chatbot from modules.features.chatbot.chatbot import Chatbot
from modules.features.chatbot.bridges.ai import AICenterChatModel, clear_workflow_allowed_providers from modules.features.chatbot.bridges.ai import AICenterChatModel, clear_workflow_allowed_providers
from modules.features.chatbot.bridges.memory import DatabaseCheckpointer from modules.features.chatbot.bridges.memory import DatabaseCheckpointer
@ -91,33 +91,28 @@ async def chatProcess(
ChatbotConversation instance ChatbotConversation instance
""" """
try: try:
# Get services with mandate and feature instance context # Get services from service center (only chat, ai, billing, streaming — avoids ~90ms legacy hub)
services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId) services = getChatbotServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
services.featureCode = 'chatbot' services.featureCode = 'chatbot'
# Load instance config and apply allowedProviders for AI calls (conversation name + main chat) # Config and model warm run in background task — return stream ~23 s faster for normal feel
chatbot_config = await _load_chatbot_config(featureInstanceId) chatbot_config = None
if chatbot_config.model.allowedProviders:
services.allowedProviders = chatbot_config.model.allowedProviders # Reuse hub's interfaceDbChat (ChatObjects) - avoids duplicate DB init
logger.info(f"Chatbot instance {featureInstanceId}: restricting to providers {chatbot_config.model.allowedProviders}") interfaceDbChat = services.interfaceDbChat
from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
interfaceDbChat = getChatbotInterface(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
# Create or load workflow (event_manager passed from route) # Create or load workflow (event_manager passed from route)
if workflowId: if workflowId:
workflow = interfaceDbChat.getWorkflow(workflowId) # Lightweight resume: minimal fetch + minimal update (no logs/messages load)
workflow = interfaceDbChat.getWorkflowMinimal(workflowId)
if not workflow: if not workflow:
raise ValueError(f"Workflow {workflowId} not found") raise ValueError(f"Workflow {workflowId} not found")
# Resume workflow: increment round number
new_round = workflow.currentRound + 1 new_round = workflow.currentRound + 1
interfaceDbChat.updateWorkflow(workflowId, { workflow = interfaceDbChat.updateWorkflowMinimal(workflowId, {
"status": "running", "status": "running",
"currentRound": new_round, "currentRound": new_round,
"lastActivity": getUtcTimestamp() "lastActivity": getUtcTimestamp()
}) })
workflow = interfaceDbChat.getWorkflow(workflowId)
logger.info(f"Resumed workflow {workflowId}, round incremented to {new_round}") logger.info(f"Resumed workflow {workflowId}, round incremented to {new_round}")
# Create event queue if it doesn't exist (for streaming) # Create event queue if it doesn't exist (for streaming)
@ -166,10 +161,7 @@ async def chatProcess(
# Create event queue for new workflow (for streaming) # Create event queue for new workflow (for streaming)
event_manager.create_queue(workflow.id) event_manager.create_queue(workflow.id)
# Reload workflow to get current message count
workflow = interfaceDbChat.getWorkflow(workflow.id)
# Process uploaded files and create ChatbotDocuments # Process uploaded files and create ChatbotDocuments
user_documents = [] user_documents = []
if userInput.listFileId and len(userInput.listFileId) > 0: if userInput.listFileId and len(userInput.listFileId) > 0:
@ -203,14 +195,19 @@ async def chatProcess(
except Exception as e: except Exception as e:
logger.error(f"Error processing file ID {fileId}: {e}", exc_info=True) logger.error(f"Error processing file ID {fileId}: {e}", exc_info=True)
# Store user message # Store user message (sequenceNr: for resume use message count, else len+1)
seq_nr = (
interfaceDbChat.getMessageCount(workflow.id) + 1
if workflowId
else len(workflow.messages) + 1
)
userMessageData: Dict[str, Any] = { userMessageData: Dict[str, Any] = {
"id": f"msg_{uuid.uuid4()}", "id": f"msg_{uuid.uuid4()}",
"conversationId": workflow.id, "conversationId": workflow.id,
"message": userInput.prompt, "message": userInput.prompt,
"role": "user", "role": "user",
"status": "first" if workflowId is None else "step", "status": "first" if workflowId is None else "step",
"sequenceNr": len(workflow.messages) + 1, "sequenceNr": seq_nr,
"publishedAt": getUtcTimestamp(), "publishedAt": getUtcTimestamp(),
"roundNumber": workflow.currentRound, "roundNumber": workflow.currentRound,
"taskNumber": 0, "taskNumber": 0,
@ -223,11 +220,12 @@ async def chatProcess(
userMessage = interfaceDbChat.createMessage(userMessageData, event_manager=None) userMessage = interfaceDbChat.createMessage(userMessageData, event_manager=None)
logger.info(f"Stored user message: {userMessage.id} with {len(user_documents)} document(s)") logger.info(f"Stored user message: {userMessage.id} with {len(user_documents)} document(s)")
# Update workflow status # Update workflow status (minimal update for lastActivity; resume already did this)
interfaceDbChat.updateWorkflow(workflow.id, { if not workflowId:
"status": "running", interfaceDbChat.updateWorkflowMinimal(workflow.id, {
"lastActivity": getUtcTimestamp() "status": "running",
}) "lastActivity": getUtcTimestamp()
})
# Pre-flight billing check before starting LangGraph (if mandateId present) # Pre-flight billing check before starting LangGraph (if mandateId present)
if mandateId: if mandateId:
@ -244,9 +242,6 @@ async def chatProcess(
config=chatbot_config, config=chatbot_config,
event_manager=event_manager event_manager=event_manager
)) ))
# Reload workflow to include new message
workflow = interfaceDbChat.getWorkflow(workflow.id)
return workflow return workflow
except Exception as e: except Exception as e:
@ -728,8 +723,7 @@ async def _convert_file_ids_to_document_references(
if not document_id: if not document_id:
try: try:
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface chatbotInterface = services.interfaceDbChat
chatbotInterface = getChatbotInterface(services.user, mandateId=services.mandateId, featureInstanceId=services.featureInstanceId)
documents = getRecordsetWithRBAC( documents = getRecordsetWithRBAC(
chatbotInterface.db, chatbotInterface.db,
ChatbotDocument, ChatbotDocument,
@ -928,6 +922,20 @@ async def _bridge_chatbot_events(
step="status" step="status"
) )
continue continue
# Handle token chunks for ChatGPT-like streaming (append to message as it's generated)
if event_type == "chunk":
content = event.get("content", "")
if content:
await event_manager.emit_event(
context_id=workflow_id,
event_type="chatdata",
data={"type": "chunk", "content": content},
event_category="chat",
message="Token chunk",
step="chunk"
)
continue
# Handle final response # Handle final response
if event_type == "final": if event_type == "final":
@ -1117,9 +1125,59 @@ async def _bridge_chatbot_events(
clear_workflow_allowed_providers(workflow_id) clear_workflow_allowed_providers(workflow_id)
def _load_chatbot_config_sync(featureInstanceId: Optional[str]) -> ChatbotConfig:
    """Load the chatbot configuration of a FeatureInstance (blocking).

    Synchronous variant, suitable for running in an executor/worker thread.

    Args:
        featureInstanceId: ID of the feature instance to load the config from.

    Returns:
        The ChatbotConfig produced by ``load_chatbot_config_from_instance``.

    Raises:
        ValueError: If ``featureInstanceId`` is empty/None, or no feature
            instance with that ID is found.
    """
    if not featureInstanceId:
        raise ValueError("featureInstanceId is required to load chatbot config")

    # Function-scope imports, as in the async loader this was split out of.
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.interfaces.interfaceFeatures import getFeatureInterface

    features = getFeatureInterface(getRootInterface().db)
    featureInstance = features.getFeatureInstance(featureInstanceId)
    if featureInstance:
        logger.info(f"Loading chatbot config from FeatureInstance {featureInstanceId}")
        return load_chatbot_config_from_instance(featureInstance)

    raise ValueError(f"FeatureInstance {featureInstanceId} not found")
def _warm_model_registry_cache_sync(
    currentUser: "User",
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None,
) -> None:
    """Call ``modelRegistry.getAvailableModels`` once and discard the result.

    Per the original note, the purpose is to warm the registry's
    available-models cache so later planner/agent model selection is a cache
    hit, and passing mandateId/featureInstanceId narrows the RBAC work.
    Blocking; intended to be run in an executor rather than on the event loop.

    Args:
        currentUser: User the model lookup is performed for.
        mandateId: Optional mandate context forwarded to the lookup.
        featureInstanceId: Optional feature instance context forwarded to
            the lookup.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.aicore.aicoreModelRegistry import modelRegistry

    rootInterface = getRootInterface()
    # Only the side effect matters here; the returned models are not used.
    modelRegistry.getAvailableModels(
        currentUser=currentUser,
        rbacInstance=rootInterface.rbac,
        mandateId=mandateId,
        featureInstanceId=featureInstanceId,
    )
async def _load_chatbot_config(featureInstanceId: Optional[str]) -> ChatbotConfig: async def _load_chatbot_config(featureInstanceId: Optional[str]) -> ChatbotConfig:
""" """
Load chatbot configuration from FeatureInstance (database). Load chatbot configuration from FeatureInstance (database).
Runs in thread pool to avoid blocking event loop on DB I/O.
Args: Args:
featureInstanceId: Feature instance ID to load config from featureInstanceId: Feature instance ID to load config from
@ -1134,20 +1192,7 @@ async def _load_chatbot_config(featureInstanceId: Optional[str]) -> ChatbotConfi
raise ValueError("featureInstanceId is required to load chatbot config") raise ValueError("featureInstanceId is required to load chatbot config")
try: try:
# Import here to avoid circular imports return await asyncio.to_thread(_load_chatbot_config_sync, featureInstanceId)
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.interfaces.interfaceFeatures import getFeatureInterface
# Get feature instance from database
rootInterface = getRootInterface()
featureInterface = getFeatureInterface(rootInterface.db)
instance = featureInterface.getFeatureInstance(featureInstanceId)
if not instance:
raise ValueError(f"FeatureInstance {featureInstanceId} not found")
logger.info(f"Loading chatbot config from FeatureInstance {featureInstanceId}")
return load_chatbot_config_from_instance(instance)
except ValueError: except ValueError:
raise raise
except Exception as e: except Exception as e:
@ -1250,9 +1295,25 @@ async def _processChatbotMessageLangGraph(
event_manager: Event manager for streaming (passed from chatProcess) event_manager: Event manager for streaming (passed from chatProcess)
""" """
try: try:
from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface # Start config + model cache warm in parallel (planner/agent need cache hit to avoid 23 s per call)
interfaceDbChat = getChatbotInterface(currentUser, mandateId=services.mandateId, featureInstanceId=featureInstanceId) config_task = asyncio.create_task(_load_chatbot_config(featureInstanceId)) if config is None else None
warm_task = asyncio.create_task(asyncio.to_thread(
_warm_model_registry_cache_sync, currentUser, services.mandateId, featureInstanceId
))
# Emit first status immediately so stream feels responsive
await event_manager.emit_event(
context_id=workflowId,
event_type="chatdata",
data={"type": "status", "label": "Starte..."},
event_category="chat",
message="Status update",
step="status",
)
# Reuse interfaceDbChat from services (ChatObjects) - avoids duplicate DB init
interfaceDbChat = services.interfaceDbChat
# Reload workflow to get current messages # Reload workflow to get current messages
workflow = interfaceDbChat.getWorkflow(workflowId) workflow = interfaceDbChat.getWorkflow(workflowId)
if not workflow: if not workflow:
@ -1272,19 +1333,10 @@ async def _processChatbotMessageLangGraph(
logger.info(f"Workflow {workflowId} was stopped, aborting processing") logger.info(f"Workflow {workflowId} was stopped, aborting processing")
return return
# Emit synthetic status for real-time UI feedback # Await config and model cache warm (planner gets cache hit, saves ~23 s)
await event_manager.emit_event( if config_task is not None:
context_id=workflowId, config = await config_task
event_type="chatdata", await warm_task
data={"type": "status", "label": "Lade Konfiguration..."},
event_category="chat",
message="Status update",
step="status"
)
# Load configuration if not passed (e.g. when resuming)
if config is None:
config = await _load_chatbot_config(featureInstanceId)
# Replace {{DATE}} placeholder in system prompt # Replace {{DATE}} placeholder in system prompt
from datetime import datetime from datetime import datetime
@ -1310,25 +1362,30 @@ async def _processChatbotMessageLangGraph(
processing_mode=processing_mode, processing_mode=processing_mode,
billing_callback=billing_callback, billing_callback=billing_callback,
workflow_id=workflowId, workflow_id=workflowId,
allowed_providers=allowed_providers allowed_providers=allowed_providers,
mandate_id=services.mandateId,
feature_instance_id=featureInstanceId,
)
# Fast planner model (gpt-4o-mini etc.) for routing - saves ~1-2 s on first response
planner_model = AICenterChatModel(
user=currentUser,
operation_type=operation_type,
processing_mode=processing_mode,
billing_callback=billing_callback,
workflow_id=workflowId,
allowed_providers=allowed_providers,
prefer_fast_model=True,
mandate_id=services.mandateId,
feature_instance_id=featureInstanceId,
) )
# Emit synthetic status for real-time UI feedback # Create memory/checkpointer (reuse interface to avoid extra DB init)
await event_manager.emit_event(
context_id=workflowId,
event_type="chatdata",
data={"type": "status", "label": "Bereite Chat vor..."},
event_category="chat",
message="Status update",
step="status"
)
# Create memory/checkpointer (uses chatbot's own DB via interfaceFeatureChatbot)
memory = DatabaseCheckpointer( memory = DatabaseCheckpointer(
user=currentUser, user=currentUser,
workflow_id=workflowId, workflow_id=workflowId,
mandateId=services.mandateId, mandateId=services.mandateId,
featureInstanceId=featureInstanceId featureInstanceId=featureInstanceId,
interface=interfaceDbChat,
) )
# Create chatbot instance with config for dynamic tool configuration # Create chatbot instance with config for dynamic tool configuration
@ -1338,7 +1395,8 @@ async def _processChatbotMessageLangGraph(
system_prompt=system_prompt, system_prompt=system_prompt,
workflow_id=workflowId, workflow_id=workflowId,
config=config, config=config,
event_manager=event_manager event_manager=event_manager,
planner_model=planner_model,
) )
# Emit synthetic status for real-time UI feedback # Emit synthetic status for real-time UI feedback

View file

@ -15,7 +15,7 @@ from typing import Dict, Any, List, Optional, Union
from passlib.context import CryptContext from passlib.context import CryptContext
import uuid import uuid
from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector
from modules.shared.configuration import APP_CONFIG from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
from modules.interfaces.interfaceBootstrap import initBootstrap from modules.interfaces.interfaceBootstrap import initBootstrap
@ -144,8 +144,7 @@ class AppObjects:
dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET") dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
dbPort = int(APP_CONFIG.get("DB_PORT", 5432)) dbPort = int(APP_CONFIG.get("DB_PORT", 5432))
# Create database connector directly self.db = _get_cached_connector(
self.db = DatabaseConnector(
dbHost=dbHost, dbHost=dbHost,
dbDatabase=dbDatabase, dbDatabase=dbDatabase,
dbUser=dbUser, dbUser=dbUser,

View file

@ -12,7 +12,7 @@ import hashlib
import math import math
from typing import Dict, Any, List, Optional, Union from typing import Dict, Any, List, Optional, Union
from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
from modules.security.rbac import RbacClass from modules.security.rbac import RbacClass
from modules.datamodels.datamodelRbac import AccessRuleContext from modules.datamodels.datamodelRbac import AccessRuleContext
@ -131,8 +131,7 @@ class ComponentObjects:
dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET") dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
dbPort = int(APP_CONFIG.get("DB_PORT", 5432)) dbPort = int(APP_CONFIG.get("DB_PORT", 5432))
# Create database connector directly self.db = _get_cached_connector(
self.db = DatabaseConnector(
dbHost=dbHost, dbHost=dbHost,
dbDatabase=dbDatabase, dbDatabase=dbDatabase,
dbUser=dbUser, dbUser=dbUser,

View file

@ -11,7 +11,7 @@ Multi-Tenant Design:
""" """
import logging import logging
from typing import List, Optional, TYPE_CHECKING from typing import Dict, List, Optional, TYPE_CHECKING
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role
from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
from modules.datamodels.datamodelMembership import ( from modules.datamodels.datamodelMembership import (
@ -163,6 +163,52 @@ class RbacClass:
return permissions return permissions
def checkResourceAccessBulk(
    self,
    user: User,
    resourcePaths: List[str],
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None
) -> Dict[str, bool]:
    """
    Check view access for multiple RESOURCE paths in one pass.

    The user's role IDs and the RESOURCE rules for those roles are fetched
    once and then evaluated in memory for every requested path, instead of
    repeating the lookups per path.
    NOTE(review): the original author states this mirrors the resolution
    logic of getUserPermissions; that method is not visible here - confirm.

    Resolution per path:
      1. For each role, keep the single best matching rule: highest
         priority first, then highest item specificity. On a full tie the
         rule seen first wins.
      2. Of those per-role winners, only the ones at the overall highest
         priority are considered.
      3. View access is granted if any of them has a truthy ``view``.

    Args:
        user: User whose access is checked. A truthy ``isSysAdmin``
            attribute grants access to every path.
        resourcePaths: Resource paths to check. Duplicates collapse into
            one key of the result and are evaluated only once.
        mandateId: Optional mandate scope, passed through to the role and
            rule lookups.
        featureInstanceId: Optional feature-instance scope, passed through
            to the role and rule lookups.

    Returns:
        Dict mapping every requested path to True (view allowed) or False.
        Empty dict when resourcePaths is empty.
    """
    # Dict keys de-duplicate the requested paths while keeping their order.
    result = {path: False for path in resourcePaths}
    if not resourcePaths:
        return result

    # SysAdmin bypass: no role or rule lookup needed.
    if getattr(user, "isSysAdmin", False):
        return dict.fromkeys(result, True)

    roleIds = self._getRoleIdsForUser(user, mandateId, featureInstanceId)
    if not roleIds:
        return result

    # Fetched once; every path below is evaluated against this same list.
    rulesWithPriority = self._getRulesForRoleIds(
        roleIds, AccessRuleContext.RESOURCE, mandateId, featureInstanceId
    )

    # Iterating the result keys visits each distinct path exactly once.
    # Only values are reassigned, so the dict size never changes mid-loop.
    for path in result:
        # roleId -> (priority, specificity, rule) of that role's best rule.
        bestRulePerRole = {}
        for priority, rule in rulesWithPriority:
            if not self._ruleMatchesItem(rule, path):
                continue
            specificity = self._getItemSpecificity(rule, path)
            current = bestRulePerRole.get(rule.roleId)
            # Strict ">" keeps the first-seen rule when both values tie.
            if current is None or (priority, specificity) > current[:2]:
                bestRulePerRole[rule.roleId] = (priority, specificity, rule)

        if not bestRulePerRole:
            continue  # no rule matches this path: stays False

        highestPriority = max(
            priority for priority, _, _ in bestRulePerRole.values()
        )
        # Only the winners at the top priority level can grant view access.
        result[path] = any(
            rule.view
            for priority, _, rule in bestRulePerRole.values()
            if priority >= highestPriority
        )
    return result
def _getRoleIdsForUser( def _getRoleIdsForUser(
self, self,
user: User, user: User,