From 6dc2afafb921c8a937965686657bfecb1b294154 Mon Sep 17 00:00:00 2001
From: Ida Dittrich
Date: Fri, 6 Mar 2026 13:46:54 +0100
Subject: [PATCH] fix: performance improvements

- app.py: Pre-warm AI connectors at module load and in lifespan
- aicoreModelRegistry.py: Connector discovery cache, getAvailableModels cache, bulk RBAC, eager prewarm
- connectorDbPostgre.py: Connector cache, contextvars for userId, eviction (max 32)
- chatbot: Uses _get_cached_connector, service center integration, BillingService exceptions instead of direct imports
- interfaceDbApp.py: Uses _get_cached_connector
- interfaceDbManagement.py: Uses _get_cached_connector
- security/rbac.py: Adds checkResourceAccessBulk
---
 app.py                                        |  19 +-
 modules/aicore/aicoreModelRegistry.py         | 120 +++--
 modules/connectors/connectorDbPostgre.py      |  81 ++-
 modules/features/chatbot/bridges/ai.py        | 132 ++++-
 modules/features/chatbot/bridges/memory.py    | 135 ++++-
 modules/features/chatbot/chatbot.py           | 490 +++++++++++++++++-
 .../chatbot/interfaceFeatureChatbot.py        | 129 ++++-
 modules/features/chatbot/mainChatbot.py       |  80 ++-
 .../features/chatbot/routeFeatureChatbot.py   |  53 +-
 modules/features/chatbot/service.py           | 212 +++++---
 modules/interfaces/interfaceDbApp.py          |   5 +-
 modules/interfaces/interfaceDbManagement.py   |   5 +-
 modules/security/rbac.py                      |  48 +-
 13 files changed, 1317 insertions(+), 192 deletions(-)

diff --git a/app.py b/app.py
index 68b51af0..c112b869 100644
--- a/app.py
+++ b/app.py
@@ -280,12 +280,29 @@ initLogging()
 logger = logging.getLogger(__name__)
 instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
 
+# Pre-warm AI connectors on process load (before lifespan). Critical for chatbot latency.
+try:
+    import modules.aicore.aicoreModelRegistry  # noqa: F401
+    logger.info("AI connectors pre-warm (app load) triggered")
+except Exception as e:
+    logger.warning(f"AI pre-warm at app load failed: {e}")
 
 # Define lifespan context manager for application startup/shutdown events
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     logger.info("Application is starting up")
 
+    # --- Pre-warm AI connectors FIRST (before any other startup work) ---
+    # Avoids 4–8 s latency on first chatbot request; must run before first use.
+    try:
+        import modules.aicore.aicoreModelRegistry  # noqa: F401 - triggers eager pre-warm
+        from modules.aicore.aicoreModelRegistry import modelRegistry
+        modelRegistry.ensureConnectorsRegistered()
+        modelRegistry.refreshModels(force=True)
+        logger.info("AI connectors and model registry pre-warmed")
+    except Exception as e:
+        logger.warning(f"AI pre-warm failed: {e}")
+
     # Bootstrap database if needed (creates initial users, mandates, roles, etc.)
     # This must happen before getting root interface
     from modules.security.rootAccess import getRootDbAppConnector
@@ -333,7 +350,7 @@ async def lifespan(app: FastAPI):
     # Register audit log cleanup scheduler
     from modules.shared.auditLogger import registerAuditLogCleanupScheduler
     registerAuditLogCleanupScheduler()
-    
+
     # Ensure billing settings and accounts exist for all mandates
     try:
         from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface
diff --git a/modules/aicore/aicoreModelRegistry.py b/modules/aicore/aicoreModelRegistry.py
index 8fd0e284..844922a2 100644
--- a/modules/aicore/aicoreModelRegistry.py
+++ b/modules/aicore/aicoreModelRegistry.py
@@ -8,7 +8,8 @@ Implements plugin-like architecture for connector discovery.
import logging import importlib import os -from typing import Dict, List, Optional, Any +import time +from typing import Dict, List, Optional, Any, Tuple from modules.datamodels.datamodelAi import AiModel from .aicoreBase import BaseConnectorAi from modules.datamodels.datamodelUam import User @@ -31,6 +32,9 @@ class ModelRegistry: self._lastRefresh: Optional[float] = None self._refreshInterval: float = 300.0 # 5 minutes self._connectorsInitialized: bool = False + self._discoveredConnectorsCache: Optional[List[BaseConnectorAi]] = None # Avoid re-instantiating on every discoverConnectors() call + self._getAvailableModelsCache: Dict[Tuple[str, int], Tuple[List[AiModel], float]] = {} # (user_id, rbac_id) -> (models, ts) + self._getAvailableModelsCacheTtl: float = 30.0 # seconds def registerConnector(self, connector: BaseConnectorAi): """Register a connector and collect its models.""" @@ -68,34 +72,38 @@ class ModelRegistry: raise def discoverConnectors(self) -> List[BaseConnectorAi]: - """Auto-discover connectors by scanning aicorePlugin*.py files.""" + """Auto-discover connectors by scanning aicorePlugin*.py files. Cached after first call to avoid 4-8 s re-init on every use.""" + if self._discoveredConnectorsCache is not None: + return self._discoveredConnectorsCache + connectors = [] connectorDir = os.path.dirname(__file__) - + # Scan for connector files for filename in os.listdir(connectorDir): if filename.startswith('aicorePlugin') and filename.endswith('.py'): moduleName = filename[:-3] # Remove .py extension - + try: # Import the module module = importlib.import_module(f'modules.aicore.{moduleName}') - + # Find connector classes (classes that inherit from BaseConnectorAi) for attrName in dir(module): attr = getattr(module, attrName) - if (isinstance(attr, type) and - issubclass(attr, BaseConnectorAi) and + if (isinstance(attr, type) and + issubclass(attr, BaseConnectorAi) and attr != BaseConnectorAi): - + # Instantiate the connector connector = attr() connectors.append(connector) logger.info(f"Discovered connector: {connector.getConnectorType()}") - + except Exception as e: logger.warning(f"Failed to discover connector from {filename}: {e}") - + + self._discoveredConnectorsCache = connectors return connectors def ensureConnectorsRegistered(self): @@ -175,24 +183,49 @@ class ModelRegistry: self.refreshModels() return [model for model in self._models.values() if model.priority == priority] - def getAvailableModels(self, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> List[AiModel]: + def getAvailableModels( + self, + currentUser: Optional[User] = None, + rbacInstance: Optional[RbacClass] = None, + mandateId: Optional[str] = None, + featureInstanceId: Optional[str] = None + ) -> List[AiModel]: """Get only available models, optionally filtered by RBAC permissions. + Results are cached per (user, rbac) for 30s to avoid repeated filtering on each LLM call. 
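+        Illustrative call (sketch; user, rbac and mandate are resolved by the caller):
+
+            models = modelRegistry.getAvailableModels(
+                currentUser=user, rbacInstance=rbac, mandateId=mandate.id
+            )
+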
Args: currentUser: Optional user object for RBAC filtering rbacInstance: Optional RBAC instance for permission checks + mandateId: Optional mandate context for faster RBAC (loads fewer roles) + featureInstanceId: Optional feature instance for RBAC context Returns: List of available models (filtered by RBAC if user provided) """ self.refreshModels() + cache_key = (currentUser.id if currentUser else "", id(rbacInstance) if rbacInstance else 0) + now = time.time() + if cache_key in self._getAvailableModelsCache: + cached_models, cached_ts = self._getAvailableModelsCache[cache_key] + if now - cached_ts < self._getAvailableModelsCacheTtl: + logger.debug(f"getAvailableModels: cache hit for user={cache_key[0][:8] if cache_key[0] else 'anon'}...") + return cached_models + allModels = list(self._models.values()) availableModels = [model for model in allModels if model.isAvailable] - - # Apply RBAC filtering if user and RBAC instance provided + + # Apply RBAC filtering if user and RBAC instance provided (batch check for performance) if currentUser and rbacInstance: - availableModels = self._filterModelsByRbac(availableModels, currentUser, rbacInstance) - + availableModels = self._filterModelsByRbac( + availableModels, currentUser, rbacInstance, mandateId, featureInstanceId + ) + + self._getAvailableModelsCache[cache_key] = (availableModels, now) + # Prune expired entries to avoid unbounded growth + expired = [k for k, (_, ts) in self._getAvailableModelsCache.items() if now - ts >= self._getAvailableModelsCacheTtl] + for k in expired: + del self._getAvailableModelsCache[k] + unavailableCount = len(allModels) - len(availableModels) if unavailableCount > 0: unavailableModels = [m.name for m in allModels if not m.isAvailable] @@ -200,32 +233,33 @@ class ModelRegistry: logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}") return availableModels - def _filterModelsByRbac(self, models: List[AiModel], currentUser: User, rbacInstance: RbacClass) -> List[AiModel]: - """Filter models based on RBAC permissions. - - Args: - models: List of models to filter - currentUser: Current user object - rbacInstance: RBAC instance for permission checks - - Returns: - Filtered list of models that user has access to - """ + def _filterModelsByRbac( + self, + models: List[AiModel], + currentUser: User, + rbacInstance: RbacClass, + mandateId: Optional[str] = None, + featureInstanceId: Optional[str] = None + ) -> List[AiModel]: + """Filter models based on RBAC permissions. 
Uses bulk check for performance.""" + paths = [] + model_paths = {} # model -> (connector_path, model_path) + for model in models: + connector_path = f"ai.model.{model.connectorType}" + model_path = f"ai.model.{model.connectorType}.{model.displayName}" + paths.extend([connector_path, model_path]) + model_paths[id(model)] = (connector_path, model_path) + # Single bulk RBAC call instead of 2*N per-model calls + access = rbacInstance.checkResourceAccessBulk( + currentUser, list(dict.fromkeys(paths)), mandateId, featureInstanceId + ) filteredModels = [] for model in models: - # Check access at both connector level and model level - connectorResourcePath = f"ai.model.{model.connectorType}" - modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}" - - # User needs access to either connector (all models) or specific model - hasConnectorAccess = checkResourceAccess(rbacInstance, currentUser, connectorResourcePath) - hasModelAccess = checkResourceAccess(rbacInstance, currentUser, modelResourcePath) - - if hasConnectorAccess or hasModelAccess: + connector_path, model_path = model_paths[id(model)] + if access.get(connector_path, False) or access.get(model_path, False): filteredModels.append(model) else: logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})") - return filteredModels def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]: @@ -305,3 +339,17 @@ class ModelRegistry: # Global registry instance modelRegistry = ModelRegistry() + +# Eager pre-warm on first import: ensures connectors are ready in this process. +# Critical for chatbot performance — avoids 4–8 s latency on first request. +# Runs when this module is first imported (lifespan or first chatbot request). +def _eager_prewarm() -> None: + try: + modelRegistry.ensureConnectorsRegistered() + modelRegistry.refreshModels(force=True) + logger.info("AI connectors and model registry pre-warmed (module load)") + except Exception as e: + logger.warning(f"AI eager pre-warm skipped: {e}") + + +_eager_prewarm() diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index 6c89a85f..c4457117 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -1,5 +1,6 @@ # Copyright (c) 2025 Patrick Motsch # All rights reserved. +import contextvars import psycopg2 import psycopg2.extras import logging @@ -99,7 +100,56 @@ def _get_model_fields(model_class) -> Dict[str, str]: return fields -# No caching needed with proper database +# Cache connectors by (host, database, port) to avoid duplicate inits for same database. +# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via +# contextvars to avoid races when concurrent requests share the same connector. +_MAX_CACHED_CONNECTORS = 32 +_connector_cache: Dict[tuple, "DatabaseConnector"] = {} +_connector_cache_order: List[tuple] = [] # FIFO order for eviction +_connector_cache_lock = threading.Lock() +_current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + "db_connector_user_id", default=None +) + + +def _get_cached_connector( + dbHost: str, + dbDatabase: str, + dbUser: str = None, + dbPassword: str = None, + dbPort: int = None, + userId: str = None, +) -> "DatabaseConnector": + """Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits. 
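+
+    Illustrative usage (sketch; host and database names are placeholders):
+
+        c1 = _get_cached_connector(dbHost="db", dbDatabase="app", dbPort=5432, userId=user_a.id)
+        c2 = _get_cached_connector(dbHost="db", dbDatabase="app", dbPort=5432, userId=user_b.id)
+        assert c1 is c2  # same (host, database, port) key -> one shared connector
+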
+ Uses contextvars for userId so concurrent requests sharing the same connector get correct _createdBy/_modifiedBy. + """ + port = int(dbPort) if dbPort is not None else 5432 + key = (dbHost, dbDatabase, port) + with _connector_cache_lock: + if key not in _connector_cache: + # Evict oldest if at capacity + while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order: + oldest_key = _connector_cache_order.pop(0) + if oldest_key in _connector_cache: + try: + _connector_cache[oldest_key].close() + except Exception as e: + logger.warning(f"Error closing evicted connector: {e}") + del _connector_cache[oldest_key] + _connector_cache[key] = DatabaseConnector( + dbHost=dbHost, + dbDatabase=dbDatabase, + dbUser=dbUser, + dbPassword=dbPassword, + dbPort=dbPort, + userId=userId, + ) + _connector_cache_order.append(key) + conn = _connector_cache[key] + # Set request-scoped userId via contextvar (avoids mutating shared connector) + if userId is not None: + _current_user_id.set(userId) + return conn class DatabaseConnector: @@ -645,24 +695,22 @@ class DatabaseConnector: if "id" in record and str(record["id"]) != recordId: raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}") - # Add metadata + # Add metadata - use contextvar for request-scoped userId when sharing connector + effective_user_id = _current_user_id.get() + if effective_user_id is None: + effective_user_id = self.userId currentTime = getUtcTimestamp() # Set _createdAt and _createdBy if this is a new record (record doesn't have _createdAt) if "_createdAt" not in record: record["_createdAt"] = currentTime - # Only set _createdBy if userId is valid (not None or empty string) - if self.userId: - record["_createdBy"] = self.userId - # No warning - empty userId is normal during bootstrap - # Also ensure _createdBy is set even if _createdAt exists but _createdBy is missing/empty + if effective_user_id: + record["_createdBy"] = effective_user_id elif "_createdBy" not in record or not record.get("_createdBy"): - if self.userId: - record["_createdBy"] = self.userId - # No warning - empty userId is normal during bootstrap - # Always update modification metadata + if effective_user_id: + record["_createdBy"] = effective_user_id record["_modifiedAt"] = currentTime - if self.userId: - record["_modifiedBy"] = self.userId + if effective_user_id: + record["_modifiedBy"] = effective_user_id with self.connection.cursor() as cursor: self._save_record(cursor, table, recordId, record, model_class) @@ -782,12 +830,13 @@ class DatabaseConnector: return False def updateContext(self, userId: str) -> None: - """Updates the context of the database connector.""" + """Updates the context of the database connector. + Sets both instance userId and contextvar for request-scoped use when connector is shared. 
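+
+        Sketch (user objects are placeholders):
+            conn = _get_cached_connector(host, database, userId=admin_user.id)
+            conn.updateContext(request_user.id)
+            # saves in this request now attribute _createdBy/_modifiedBy to request_user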
+        """
         if userId is None:
             raise ValueError("userId must be provided")
-
         self.userId = userId
-        # No cache to clear - database handles data consistency
+        _current_user_id.set(userId)
 
 
 # Public API
diff --git a/modules/features/chatbot/bridges/ai.py b/modules/features/chatbot/bridges/ai.py
index 283a9e4e..a06668c8 100644
--- a/modules/features/chatbot/bridges/ai.py
+++ b/modules/features/chatbot/bridges/ai.py
@@ -38,9 +38,10 @@ from modules.datamodels.datamodelUam import User
 
 logger = logging.getLogger(__name__)
 
-# Workflow-level store for allowed_providers (survives LangGraph/bind_tools execution context
-# where instance attributes may be lost when model is wrapped or serialized)
+# Workflow-level store for allowed_providers and RBAC context (survives LangGraph/bind_tools
+# execution context where instance attributes may be lost when model is wrapped or serialized)
 _workflow_allowed_providers: Dict[str, List[str]] = {}
+_workflow_rbac_context: Dict[str, tuple] = {}  # workflow_id -> (mandateId, featureInstanceId)
 
 
 def clear_workflow_allowed_providers(workflow_id: str) -> None:
@@ -62,11 +63,14 @@ class AICenterChatModel(BaseChatModel):
         billing_callback: Optional[Callable[[AiCallResponse], None]] = None,
         workflow_id: Optional[str] = None,
         allowed_providers: Optional[List[str]] = None,
+        prefer_fast_model: bool = False,
+        mandate_id: Optional[str] = None,
+        feature_instance_id: Optional[str] = None,
         **kwargs
     ):
         """
         Initialize the AI center chat model bridge.
-
+
         Args:
             user: Current user for RBAC and model selection
             operation_type: Operation type for model selection
@@ -74,6 +78,9 @@
             billing_callback: Optional callback invoked after each _agenerate with AiCallResponse for billing
             workflow_id: Optional workflow/conversation ID for billing context
             allowed_providers: Optional list of allowed provider connector types (empty/None = all)
+            prefer_fast_model: When True, strongly prefer faster models (e.g. gpt-4o-mini for planner)
+            mandate_id: Optional mandate context forwarded to RBAC model filtering
+            feature_instance_id: Optional feature instance context forwarded to RBAC model filtering
             **kwargs: Additional arguments passed to BaseChatModel
         """
         super().__init__(**kwargs)
@@ -85,9 +90,14 @@
         object.__setattr__(self, "_billing_callback", billing_callback)
         object.__setattr__(self, "_workflow_id", workflow_id)
         object.__setattr__(self, "_allowed_providers", allowed_providers or [])
+        object.__setattr__(self, "_prefer_fast_model", prefer_fast_model)
+        object.__setattr__(self, "_mandate_id", mandate_id)
+        object.__setattr__(self, "_feature_instance_id", feature_instance_id)
         # Store in workflow-level registry so it survives when instance attrs are lost (e.g.
bind_tools) if workflow_id and allowed_providers: _workflow_allowed_providers[workflow_id] = list(allowed_providers) + if workflow_id and (mandate_id is not None or feature_instance_id is not None): + _workflow_rbac_context[workflow_id] = (mandate_id, feature_instance_id) @property def _llm_type(self) -> str: @@ -129,17 +139,25 @@ class AICenterChatModel(BaseChatModel): # Get available models with RBAC filtering # Use cached/singleton interfaces for better performance from modules.interfaces.interfaceDbApp import getRootInterface - + + workflow_id = getattr(self, "_workflow_id", None) rootInterface = getRootInterface() rbac_instance = rootInterface.rbac - + + mandate_id = getattr(self, "_mandate_id", None) + feature_instance_id = getattr(self, "_feature_instance_id", None) + if workflow_id and (mandate_id is None and feature_instance_id is None): + ctx = _workflow_rbac_context.get(workflow_id) + if ctx: + mandate_id, feature_instance_id = ctx available_models = modelRegistry.getAvailableModels( currentUser=self.user, - rbacInstance=rbac_instance + rbacInstance=rbac_instance, + mandateId=mandate_id, + featureInstanceId=feature_instance_id, ) # Allowed providers: instance attr or workflow store (lost in LangGraph/bind_tools context) - workflow_id = getattr(self, '_workflow_id', None) allowed = ( (_workflow_allowed_providers.get(workflow_id) if workflow_id else None) or getattr(self, '_allowed_providers', None) @@ -155,7 +173,8 @@ class AICenterChatModel(BaseChatModel): options = AiCallOptions( operationType=self.operation_type, processingMode=self.processing_mode, - allowedProviders=allowed if allowed else None + allowedProviders=allowed if allowed else None, + preferFastModel=getattr(self, "_prefer_fast_model", False), ) # Select model @@ -246,7 +265,97 @@ class AICenterChatModel(BaseChatModel): # Run the async method synchronously return asyncio.run(self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs)) - + + async def _call_openai_streaming( + self, + ai_messages: List[dict], + run_manager: Optional[Any], + model_call: "AiModelCall", + input_bytes: int, + start_time: float, + ) -> "AiModelResponse": + """Call OpenAI/Ollama with stream=True, emit tokens via run_manager, return full response.""" + import httpx + import json as _json + from modules.shared.configuration import APP_CONFIG + + if self._selected_model.connectorType == "openai": + api_url = getattr(self._selected_model, "apiUrl", None) or "https://api.openai.com/v1/chat/completions" + api_key = APP_CONFIG.get("Connector_AiOpenai_API_SECRET") + if not api_key: + raise ValueError("OpenAI API key not configured") + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + ollama_model = self._selected_model.name + else: + base_url = getattr(self._selected_model, "apiUrl", "").replace("/api/analyze", "") + api_url = f"{base_url.rstrip('/')}/v1/chat/completions" + api_key = APP_CONFIG.get("Connector_AiPrivateLlm_API_SECRET") + headers = {"Content-Type": "application/json"} + if api_key: + headers["X-API-Key"] = api_key + ollama_model = getattr(self._selected_model, "version", None) or self._selected_model.name + + payload = { + "model": ollama_model, + "messages": ai_messages, + "temperature": self._selected_model.temperature, + "max_tokens": self._selected_model.maxTokens, + "stream": True, + } + content_parts: List[str] = [] + async with httpx.AsyncClient(timeout=600.0) as client: + async with client.stream("POST", api_url, headers=headers, json=payload) as resp: + if resp.status_code 
!= 200:
+                    raise ValueError(f"OpenAI stream error: {resp.status_code} - {await resp.aread()}")
+                buffer = ""
+                async for chunk in resp.aiter_text():
+                    buffer += chunk
+                    while "\n" in buffer:
+                        line, _, buffer = buffer.partition("\n")
+                        line = line.strip()
+                        if line.startswith("data: "):
+                            data_str = line[6:].strip()
+                            if data_str == "[DONE]":
+                                break
+                            try:
+                                data = _json.loads(data_str)
+                                choices = data.get("choices") or []
+                                if choices:
+                                    delta = choices[0].get("delta") or {}
+                                    token = delta.get("content") or ""
+                                    if token and run_manager and hasattr(run_manager, "on_llm_new_token"):
+                                        # may be sync or async depending on the callback manager
+                                        maybe = run_manager.on_llm_new_token(token)
+                                        if hasattr(maybe, "__await__"):
+                                            await maybe
+                                    content_parts.append(token)
+                            except _json.JSONDecodeError:
+                                pass
+
+        content = "".join(content_parts)
+        processing_time = time.time() - start_time
+        output_bytes = len(content.encode("utf-8"))
+        price_chf = 0.0
+        if getattr(self._selected_model, "calculatepriceCHF", None):
+            try:
+                price_chf = self._selected_model.calculatepriceCHF(processing_time, input_bytes, output_bytes)
+            except Exception:
+                pass
+        billing_callback = getattr(self, "_billing_callback", None)
+        if billing_callback:
+            try:
+                billing_callback(AiCallResponse(
+                    content=content,
+                    modelName=self._selected_model.name,
+                    provider=self._selected_model.connectorType or "unknown",
+                    priceCHF=price_chf,
+                    processingTime=processing_time,
+                    bytesSent=input_bytes,
+                    bytesReceived=output_bytes,
+                    errorCount=0,
+                ))
+            except Exception as e:
+                logger.error(f"Billing callback error: {e}")
+
+        return AiModelResponse(content=content, success=True, modelId=self._selected_model.name, metadata={})
+
     async def _agenerate(
         self,
         messages: List[BaseMessage],
@@ -484,6 +593,11 @@
                     "tool_calls": tool_calls,
                 },
             )
+        elif not tools and self._selected_model.connectorType in ("openai", "privatellm"):
+            # Streaming path for OpenAI/Ollama without tools (ChatGPT-like token streaming)
+            response = await self._call_openai_streaming(
+                ai_messages, run_manager, model_call, input_bytes, start_time
+            )
         else:
             # No tools or not OpenAI - use connector normally
             if not self._selected_model.functionCall:
diff --git a/modules/features/chatbot/bridges/memory.py b/modules/features/chatbot/bridges/memory.py
index 234dc041..2d9251c1 100644
--- a/modules/features/chatbot/bridges/memory.py
+++ b/modules/features/chatbot/bridges/memory.py
@@ -5,6 +5,7 @@ Custom LangGraph checkpointer using existing database interface.
 Maps LangGraph state to existing message storage format.
 """
 
+import contextvars
 import logging
 import uuid
 from typing import Any, Dict, List, Optional, Tuple, NamedTuple
@@ -47,19 +48,34 @@ class DatabaseCheckpointer(BaseCheckpointSaver):
     Maps LangGraph thread_id to conversation.id; stores messages via interface
     (workflowId maps to conversationId).
     """
-    def __init__(self, user: User, workflow_id: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
+    def __init__(
+        self,
+        user: User,
+        workflow_id: str,
+        mandateId: Optional[str] = None,
+        featureInstanceId: Optional[str] = None,
+        *,
+        interface=None,
+    ):
         """
         Initialize the database checkpointer.
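+        Sketch (reuses an already-created interface; names are illustrative):
+            iface = getChatbotInterface(user, mandateId=mid, featureInstanceId=fid)
+            cp = DatabaseCheckpointer(user, workflow_id, mandateId=mid,
+                                      featureInstanceId=fid, interface=iface)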
- + Args: user: Current user for database access workflow_id: Workflow ID (maps to LangGraph thread_id) mandateId: Mandate ID for proper data isolation featureInstanceId: Feature instance ID for proper data isolation + interface: Optional pre-created chatbot interface (avoids extra getInterface + DB init) """ self.user = user self.workflow_id = workflow_id - self.interface = getChatbotInterface(user, mandateId=mandateId, featureInstanceId=featureInstanceId) + self.interface = interface if interface is not None else getChatbotInterface( + user, mandateId=mandateId, featureInstanceId=featureInstanceId + ) def _convert_langchain_to_db_message( self, @@ -445,3 +457,120 @@ class DatabaseCheckpointer(BaseCheckpointSaver): # Not implemented - using aput() instead # This method is called by LangGraph but we handle writes through aput() pass + + +# ContextVar for per-request checkpointer (used by CheckpointerResolver for graph caching) +_current_checkpointer: contextvars.ContextVar[Optional[BaseCheckpointSaver]] = contextvars.ContextVar( + "chatbot_current_checkpointer", default=None +) + + +def set_checkpointer(checkpointer: BaseCheckpointSaver) -> contextvars.Token: + """Set the current request's checkpointer. Returns token to reset later.""" + return _current_checkpointer.set(checkpointer) + + +def reset_checkpointer(token: contextvars.Token) -> None: + """Reset checkpointer to prior value. Safe when called from a different async context.""" + try: + _current_checkpointer.reset(token) + except ValueError: + # Token was created in a different context (e.g. after yield, generator cleanup) + pass + + +class CheckpointerResolver(BaseCheckpointSaver): + """ + Delegating checkpointer that reads the real checkpointer from context. + Used for graph caching: the compiled graph uses this resolver; at invoke time + the per-request checkpointer is set via set_checkpointer(). + """ + + def _get_checkpointer(self) -> BaseCheckpointSaver: + cp = _current_checkpointer.get() + if cp is None: + raise RuntimeError( + "CheckpointerResolver: no checkpointer in context. " + "Call set_checkpointer() before invoking the cached graph." 
+ ) + return cp + + def put( + self, + config: Dict[str, Any], + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: Dict[str, int], + ) -> None: + self._get_checkpointer().put(config, checkpoint, metadata, new_versions) + + def get(self, config: Dict[str, Any]) -> Optional[Checkpoint]: + return self._get_checkpointer().get(config) + + def list( + self, + config: Dict[str, Any], + filter: Optional[Dict[str, Any]] = None, + before: Optional[str] = None, + limit: Optional[int] = None, + ) -> List[Checkpoint]: + return self._get_checkpointer().list(config, filter, before, limit) + + def put_writes( + self, + config: Dict[str, Any], + writes: List[Tuple[str, Any]], + task_id: str, + ) -> None: + self._get_checkpointer().put_writes(config, writes, task_id) + + async def aget_tuple(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]: + inner = self._get_checkpointer() + if hasattr(inner, "aget_tuple"): + return await inner.aget_tuple(config) + checkpoint = inner.get(config) + if checkpoint: + metadata: CheckpointMetadata = {"step": 0} + return CheckpointTuple( + config=config, + checkpoint=checkpoint, + metadata=metadata, + parent_config=None, + pending_writes=None, + ) + return None + + async def aput( + self, + config: Dict[str, Any], + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: Dict[str, int], + ) -> None: + inner = self._get_checkpointer() + if hasattr(inner, "aput"): + await inner.aput(config, checkpoint, metadata, new_versions) + else: + inner.put(config, checkpoint, metadata, new_versions) + + async def alist( + self, + config: Dict[str, Any], + filter: Optional[Dict[str, Any]] = None, + before: Optional[str] = None, + limit: Optional[int] = None, + ) -> List[Checkpoint]: + inner = self._get_checkpointer() + if hasattr(inner, "alist"): + return await inner.alist(config, filter, before, limit) + return inner.list(config, filter, before, limit) + + async def aput_writes( + self, + config: Dict[str, Any], + writes: List[Tuple[str, Any]], + task_id: str, + ) -> None: + inner = self._get_checkpointer() + if hasattr(inner, "aput_writes"): + await inner.aput_writes(config, writes, task_id) diff --git a/modules/features/chatbot/chatbot.py b/modules/features/chatbot/chatbot.py index 2b983372..e91e4f99 100644 --- a/modules/features/chatbot/chatbot.py +++ b/modules/features/chatbot/chatbot.py @@ -2,8 +2,10 @@ # All rights reserved. """Chatbot domain logic.""" +import contextvars import re import logging +import threading from dataclasses import dataclass, field from typing import Annotated, AsyncIterator, Any, List, Optional, TYPE_CHECKING from pydantic import BaseModel @@ -21,7 +23,12 @@ from langgraph.graph import StateGraph, START, END from langgraph.graph.state import CompiledStateGraph from modules.features.chatbot.bridges.ai import AICenterChatModel -from modules.features.chatbot.bridges.memory import DatabaseCheckpointer +from modules.features.chatbot.bridges.memory import ( + CheckpointerResolver, + DatabaseCheckpointer, + set_checkpointer, + reset_checkpointer, +) from modules.features.chatbot.bridges.tools import ( create_sql_query_tool, create_tavily_search_tool, @@ -168,13 +175,412 @@ class ChatState(BaseModel): plan: Optional[str] = None # Planner routing: "SQL", "TAVILY", "BOTH", "NONE" +@dataclass +class ChatbotGraphContext: + """Per-request context for cached graph execution. 
Nodes read model/tools from here.""" + + model: AICenterChatModel + planner_model: AICenterChatModel + tools: List[Any] + tools_by_name: dict + sql_tool: Any + tavily_tool: Any + streaming_tool: Any + prompt_sections: dict + system_prompt: str + + +_graph_context: contextvars.ContextVar[Optional[ChatbotGraphContext]] = contextvars.ContextVar( + "chatbot_graph_context", default=None +) + + +def _get_graph_context() -> ChatbotGraphContext: + ctx = _graph_context.get() + if ctx is None: + raise RuntimeError( + "ChatbotGraphContext not set. Ensure graph context is set before invoking cached graph." + ) + return ctx + + +def _set_graph_context(ctx: ChatbotGraphContext) -> contextvars.Token: + return _graph_context.set(ctx) + + +def _reset_graph_context(token: contextvars.Token) -> None: + """Reset graph context. Safe when called from a different async context (e.g. generator cleanup).""" + try: + _graph_context.reset(token) + except ValueError: + # Token was created in a different context (e.g. after yield, generator cleanup) + pass + + +# Cached compiled graph; lock for thread-safe cache access +_compiled_graph_cache: Optional[CompiledStateGraph] = None +_compiled_graph_lock = threading.Lock() + + +def _get_or_build_cached_graph() -> CompiledStateGraph: + """Return cached compiled graph or build and cache it. Thread-safe.""" + global _compiled_graph_cache + with _compiled_graph_lock: + if _compiled_graph_cache is not None: + return _compiled_graph_cache + _compiled_graph_cache = _build_cached_graph() + logger.info("Chatbot: compiled graph cached for reuse") + return _compiled_graph_cache + + +def _build_cached_graph() -> CompiledStateGraph: + """Build the chatbot graph with context-resolved nodes and CheckpointerResolver.""" + checkpointer = CheckpointerResolver() + + PLANNER_SYSTEM = ( + "Du bist ein Assistent. Antworte NUR mit einem Wort: SQL, TAVILY, BOTH oder NONE.\n" + "SQL = Fragen zu Lager, Bestand, Artikel, Preisen, wie viele, Anzahl (Datenbankabfrage).\n" + "TAVILY = Internetsuche, Produktinfos außerhalb der DB, Markttrends.\n" + "BOTH = beides nötig. NONE = nur Begrüßung oder Danksagung, keine Daten nötig.\n" + "Beispiele: 'wie viele X auf Lager' -> SQL, 'Infos zu Produkt Y' -> TAVILY." + ) + SCHEMA_TRUNCATION_SUFFIX = ( + "\n\n[... Schema gekürzt. Wichtige Tabellen: Artikel, Lagerplatz_Artikel, Einkaufspreis, Lagerplatz. " + "Artikel-Spalte: a.\"Artikelbezeichnung\". " + "JOIN: Artikel a, Lagerplatz_Artikel l ON a.I_ID = l.R_ARTIKEL, Lagerplatz lp ON l.R_LAGERPLATZ = lp.I_ID.]" + ) + SQL_PLAN_SUFFIX = ( + "\n\n--- AUSGABEFORMAT (PFLICHT) ---\n" + "Antworte NUR mit einer SQL SELECT-Abfrage in diesem Format:\n" + "```sql\nDEINE_SQL_QUERY\n```\n" + "KRITISCH bei 'wie viele X auf Lager': Liefere ARTIKELZEILEN (Artikelnummer, Artikelbezeichnung, Bestand) " + "mit LIMIT 20, NICHT nur SELECT COUNT(*). Ohne Detailzeilen kann keine Tabelle angezeigt werden. " + "Gesamtanzahl optional via Unterabfrage im SELECT." + ) + FORMULATE_TASK = ( + "\n\n--- AKTUELLE AUFGABE ---\n" + "Du erhältst eine Benutzerfrage und die exakten Datenbankergebnisse. " + "KRITISCH: Nutze NUR die gelieferten Daten. Erfinde NIEMALS Daten (keine LED-A01, LED Rot, etc.). " + "Wenn die Ergebnisse NUR eine Zahl enthalten (z.B. '1. COUNT(*): 806'): Reportiere NUR diese Zahl, KEINE erfundene Tabelle. " + "Eine Tabelle darf NUR erstellt werden, wenn echte Zeilen '1. Spalte: Wert, ...' in den Daten stehen. " + "Beachte die obige ANTWORTSTRUKTUR." 
+ ) + bytes_per_token = 3 + reserved_tokens = 3000 + _SQL_KEYWORDS = ( + "lager", "bestand", "artikel", "wie viele", "anzahl", "preis", + "lieferant", "lieferanten", "bestellen", "verfügbar", "inventar" + ) + + def _get_context_length(ctx: ChatbotGraphContext) -> int: + if hasattr(ctx.model, "_selected_model") and ctx.model._selected_model: + return getattr(ctx.model._selected_model, "contextLength", 128000) + return 128000 + + def _truncate_system_prompt(full_prompt: str, max_chars: int, suffix: str = "") -> str: + if len(full_prompt) <= max_chars: + return full_prompt + return full_prompt[: max_chars - len(suffix)] + suffix + + async def planner_node(state: ChatState) -> dict: + ctx = _get_graph_context() + human_msgs = [m for m in state.messages if isinstance(m, HumanMessage)] + last_human = human_msgs[-1].content if human_msgs else "" + window = [SystemMessage(content=PLANNER_SYSTEM), HumanMessage(content=last_human)] + plan = "SQL" + try: + response = await ctx.planner_model.ainvoke(window) + except ValueError as exc: + if "No suitable model found" in str(exc): + logger.warning(f"Planner model selection failed: {exc}") + return {"plan": plan} + raise + content = (response.content or "").strip().upper() + for keyword in ("SQL", "TAVILY", "BOTH", "NONE"): + if keyword in content: + plan = keyword + break + return {"plan": plan} + + def route_by_plan(state: ChatState) -> str: + ctx = _get_graph_context() + plan = (state.plan or "SQL").upper() + if plan == "NONE" and ctx.sql_tool: + last_user = "" + for m in reversed(state.messages): + if isinstance(m, HumanMessage): + last_user = (m.content or "").lower() + break + if any(kw in last_user for kw in _SQL_KEYWORDS): + logger.info("Planner returned NONE but user asked inventory question - routing to SQL") + plan = "SQL" + if plan in ("SQL", "BOTH") and ctx.sql_tool: + return "agent_sql_plan" + if plan == "TAVILY" and ctx.tavily_tool: + return "agent_tavily" + return "agent_answer" + + def select_window(ctx: ChatbotGraphContext, msgs: List[BaseMessage], max_tokens_override: Optional[int] = None) -> List[BaseMessage]: + def approx_counter(items: List[BaseMessage]) -> int: + return sum(len(getattr(m, "content", "") or "") for m in items) + max_tokens = max_tokens_override or _get_context_length(ctx) + return trim_messages( + msgs, + strategy="last", + token_counter=approx_counter, + max_tokens=int(max_tokens * 0.8), + start_on="human", + end_on=("human", "tool"), + include_system=True, + ) + + async def _agent_common(state: ChatState, system_content: str, llm: Any, node_name: str) -> dict: + ctx = _get_graph_context() + msgs = select_window(ctx, state.messages) + if not msgs or not isinstance(msgs[0], SystemMessage): + window = [SystemMessage(content=system_content)] + msgs + else: + window = [SystemMessage(content=system_content)] + [m for m in msgs if not isinstance(m, SystemMessage)] + try: + response = await llm.ainvoke(window) + except ValueError as exc: + if "No suitable model found" in str(exc): + logger.warning(f"{node_name} model selection failed: {exc}") + response = AIMessage( + content="Es tut mir leid, derzeit steht kein passendes KI-Modell für diese Anfrage zur Verfügung. " + "Bitte versuchen Sie es später erneut oder wenden Sie sich an den Administrator." 
+ ) + else: + raise + return {"messages": [response]} + + def _parse_sql_from_content(content: str) -> Optional[str]: + if not content: + return None + match = re.search(r"```(?:sql)?\s*([\s\S]*?)```", content) + if match: + sql = match.group(1).strip() + if sql and sql.upper().strip().startswith("SELECT"): + return sql + for line in content.split("\n"): + line = line.strip() + if line.upper().startswith("SELECT"): + return line + return None + + def _sanitize_sql_typos(sql: str) -> str: + if not sql: + return sql + sql = re.sub(r"WHEN([A-Za-z_][A-Za-z0-9_.\"]*)", r"WHEN \1", sql, flags=re.IGNORECASE) + sql = re.sub(r"\bLAGerplatz_Artikel\b", "Lagerplatz_Artikel", sql) + sql = re.sub(r"\bLAGerplatz\b", "Lagerplatz", sql) + sql = sql.replace('"Einkaufspreis_neu"', '"Einkaufspreis"') + sql = sql.replace("Einkaufspreis_neu.", "Einkaufspreis.") + sql = re.sub(r'"Einkaufspreis"\."ARTIKEL"', '"Einkaufspreis"."m_Artikel"', sql) + return sql + + async def agent_sql_plan_node(state: ChatState) -> dict: + ctx = _get_graph_context() + ctx_len = _get_context_length(ctx) + max_system_chars = max(1000, int(ctx_len * 0.8 - reserved_tokens) * bytes_per_token) - len(SQL_PLAN_SUFFIX) + schema_part = ctx.prompt_sections.get("schema") or ctx.prompt_sections.get("intro", "") + intro_part = (ctx.prompt_sections.get("intro", "") or "")[:400] + combined = f"{intro_part}\n\n{schema_part}" if intro_part else schema_part + system_content = _truncate_system_prompt(combined, max_system_chars, SCHEMA_TRUNCATION_SUFFIX) + SQL_PLAN_SUFFIX + llm = ctx.model + return await _agent_common(state, system_content, llm, "agent_sql_plan") + + async def parse_execute_sql_node(state: ChatState) -> dict: + ctx = _get_graph_context() + sql_t = ctx.sql_tool + last_msg = state.messages[-1] if state.messages else None + if not isinstance(last_msg, AIMessage): + return {"messages": [ToolMessage(content="Fehler: Keine AI-Antwort zum Parsen.", tool_call_id="parse_0", name="sqlite_query")]} + sql = _parse_sql_from_content(last_msg.content or "") + if not sql or not sql_t: + return {"messages": [ToolMessage(content="Konnte keine SQL-Abfrage aus der Antwort extrahieren.", tool_call_id="parse_0", name="sqlite_query")]} + sql = _sanitize_sql_typos(sql) + try: + result = await sql_t.ainvoke({"query": sql}) + except Exception as e: + logger.error(f"SQL execution failed: {e}") + result = f"Fehler bei der Ausführung: {e}" + return {"messages": [ToolMessage(content=str(result), tool_call_id="parse_0", name="sqlite_query")]} + + async def agent_formulate_node(state: ChatState) -> dict: + ctx = _get_graph_context() + human_content = "" + tool_content = "" + for m in state.messages: + if isinstance(m, HumanMessage): + human_content = m.content or "" + if isinstance(m, ToolMessage) and getattr(m, "name", "") == "sqlite_query": + tool_content = m.content or "" + if not tool_content or not tool_content.strip(): + logger.warning("agent_formulate: no tool_content (sqlite_query) in state.messages") + return {"messages": [AIMessage(content="Die Datenbankabfrage konnte keine Ergebnisse liefern. Bitte versuchen Sie es mit einer anderen Formulierung.")]} + if "Query failed" in tool_content or tool_content.strip().startswith("Error"): + err_summary = "Die Datenbankabfrage ist fehlgeschlagen." + if "no such column" in tool_content: + err_summary += " Ein Spaltenname scheint nicht zu passen. Bitte die Anfrage anders formulieren." 
+ return {"messages": [AIMessage(content=err_summary)]} + formatted_data = _tool_output_to_markdown_table(tool_content) + logger.debug(f"agent_formulate: tool_content length={len(tool_content)}, formatted={len(formatted_data)}") + ctx_len = _get_context_length(ctx) + max_system_chars = max(3000, int(ctx_len * 0.5) * bytes_per_token) - len(FORMULATE_TASK) + resp_struct = ctx.prompt_sections.get("response_structure") or ctx.prompt_sections.get("intro", "") + intro_formulate = ctx.prompt_sections.get("intro", "") + combined = f"{intro_formulate}\n\n{resp_struct}" if intro_formulate != resp_struct else resp_struct + if len(combined) + len(FORMULATE_TASK) > max_system_chars: + combined = _truncate_system_prompt(combined, max_system_chars - len(FORMULATE_TASK), "") + system_content = combined + FORMULATE_TASK + prompt = ( + f"Benutzerfrage: {human_content}\n\n" + "--- VORGEGEBENE DATEN (diese Tabelle/Zahlen UNVERÄNDERT in die Antwort übernehmen): ---\n" + f"{formatted_data}\n\n" + "Die obige Tabelle bzw. Zahlen sind die EINZIGEN erlaubten Daten. Kopiere sie 1:1. " + "Berechne keine eigenen Summen/Anzahlen - nutze die gelieferten Werte. Formuliere die Antwort:" + ) + window = [SystemMessage(content=system_content), HumanMessage(content=prompt)] + try: + response = await ctx.model.ainvoke(window) + except ValueError as exc: + if "No suitable model found" in str(exc): + response = AIMessage(content="Es gab einen Fehler bei der Formulierung. Bitte versuchen Sie es erneut.") + else: + raise + if response.content: + response = AIMessage(content=_sanitize_llm_response(response.content)) + return {"messages": [response]} + + async def agent_tavily_node(state: ChatState) -> dict: + ctx = _get_graph_context() + resp_struct = ctx.prompt_sections.get("response_structure") or "" + intro_tavily = ctx.prompt_sections.get("intro", "") + combined = f"{intro_tavily}\n\n{resp_struct}" if resp_struct else intro_tavily + system_content = _truncate_system_prompt(combined, 6000, "") + tools_tavily = [t for t in [ctx.tavily_tool, ctx.streaming_tool] if t is not None] + llm_tavily = ctx.model.bind_tools(tools=tools_tavily) if tools_tavily else ctx.model + return await _agent_common(state, system_content, llm_tavily, "agent_tavily") + + async def agent_answer_node(state: ChatState) -> dict: + ctx = _get_graph_context() + resp_struct = ctx.prompt_sections.get("response_structure") or "" + intro_answer = ctx.prompt_sections.get("intro", "") + combined = f"{intro_answer}\n\n{resp_struct}" if resp_struct else intro_answer + system_content = _truncate_system_prompt(combined, 6000, "") + llm = ctx.planner_model if ctx.planner_model else ctx.model + return await _agent_common(state, system_content, llm, "agent_answer") + + def should_continue_tavily(state: ChatState) -> str: + last = state.messages[-1] + return "tools" if getattr(last, "tool_calls", None) else END + + def route_back(state: ChatState) -> str: + ctx = _get_graph_context() + return "agent_tavily" if ctx.tavily_tool else "agent_answer" + + async def tools_with_retry(state: ChatState) -> dict: + import asyncio + ctx = _get_graph_context() + last_message = state.messages[-1] + tool_calls = getattr(last_message, "tool_calls", []) + if not tool_calls: + return {"messages": []} + tools_by_name = ctx.tools_by_name + + async def execute_single_tool(tool_call): + tool_name = tool_call.get("name") or tool_call.get("function", {}).get("name") + tool_id = tool_call.get("id", f"call_{tool_name}") + args = tool_call.get("args") or tool_call.get("function", {}).get("arguments", 
{}) + if isinstance(args, str): + import json + try: + args = json.loads(args) + except Exception: + args = {"input": args} + tool = tools_by_name.get(tool_name) + if not tool: + return ToolMessage(content=f"Error: Tool '{tool_name}' not found", tool_call_id=tool_id, name=tool_name) + try: + if hasattr(tool, "coroutine") and asyncio.iscoroutinefunction(tool.coroutine): + result = await tool.coroutine(**args) + elif hasattr(tool, "ainvoke"): + result = await tool.ainvoke(args) + else: + result = tool.invoke(args) + return ToolMessage(content=str(result), tool_call_id=tool_id, name=tool_name) + except Exception as e: + logger.error(f"Tool {tool_name} failed: {e}") + return ToolMessage(content=f"Error executing {tool_name}: {str(e)}", tool_call_id=tool_id, name=tool_name) + + tool_messages = await asyncio.gather( + *[execute_single_tool(tc) for tc in tool_calls], + return_exceptions=True + ) + result_messages = [] + for i, msg in enumerate(tool_messages): + if isinstance(msg, Exception): + tool_call = tool_calls[i] + tool_name = tool_call.get("name", "unknown") + tool_id = tool_call.get("id", f"call_{i}") + result_messages.append(ToolMessage(content=f"Error: {str(msg)}", tool_call_id=tool_id, name=tool_name)) + else: + result_messages.append(msg) + result = {"messages": result_messages} + no_results_keywords = [ + "returned 0 rows", "no data", "keine artikel gefunden", "keine ergebnisse" + ] + for msg in result.get("messages", []): + content = getattr(msg, "content", "") + if isinstance(content, str): + content_lower = content.lower() + if any(keyword in content_lower for keyword in no_results_keywords): + retry_count = sum(1 for m in state.messages if "retry" in str(getattr(m, "content", "")).lower()) + if retry_count < 2: + logger.info("No results found in tool output, adding retry instruction") + retry_message = HumanMessage( + content="WICHTIG: Die vorherige Suche hat keine Ergebnisse gefunden. " + "Bitte versuche eine alternative Suchstrategie:\n" + "1. Wenn die Frage im Format 'X von Y' war (z.B. 'Lampen von Eaton'), " + "verwende IMMER eine Kombination aus Lieferanten-Filter (WHERE a.\"Lieferant\" LIKE '%Y%') " + "UND Produkttyp-Filter (WHERE a.\"Artikelbezeichnung\" LIKE '%X%' OR ...)\n" + "2. Verwende mehrere Synonyme für den Produkttyp (z.B. bei 'Lampen': Lampe, LED, Beleuchtung, Licht, Leuchte, Strahler)\n" + "3. Führe zuerst eine COUNT-Abfrage durch, dann die Detail-Abfrage mit Lagerbeständen\n" + "4. 
Verwende LIKE '%Lieferant%' für den Lieferanten-Filter, um auch Varianten zu finden" + ) + result["messages"].append(retry_message) + break + return result + + workflow = StateGraph(ChatState) + workflow.add_node("planner", planner_node) + workflow.add_node("agent_sql_plan", agent_sql_plan_node) + workflow.add_node("parse_execute_sql", parse_execute_sql_node) + workflow.add_node("agent_formulate", agent_formulate_node) + workflow.add_node("tools", tools_with_retry) + workflow.add_node("agent_tavily", agent_tavily_node) + workflow.add_node("agent_answer", agent_answer_node) + workflow.add_edge(START, "planner") + workflow.add_conditional_edges("planner", route_by_plan) + workflow.add_edge("agent_sql_plan", "parse_execute_sql") + workflow.add_edge("parse_execute_sql", "agent_formulate") + workflow.add_edge("agent_formulate", END) + workflow.add_conditional_edges("agent_tavily", should_continue_tavily) + workflow.add_edge("agent_answer", END) + workflow.add_conditional_edges("tools", route_back) + return workflow.compile(checkpointer=checkpointer) + + @dataclass class Chatbot: """Represents a chatbot.""" model: AICenterChatModel memory: DatabaseCheckpointer + planner_model: Optional[AICenterChatModel] = None # Fast model for routing (SQL/TAVILY/NONE) app: CompiledStateGraph = None + _tools: List[Any] = field(default_factory=list) # Configured tools (for cached graph context) system_prompt: str = "You are a helpful assistant." workflow_id: str = "default" config: Optional["ChatbotConfig"] = None @@ -189,6 +595,7 @@ class Chatbot: workflow_id: str = "default", config: Optional["ChatbotConfig"] = None, event_manager=None, + planner_model: Optional[AICenterChatModel] = None, ) -> "Chatbot": """Factory method to create and configure a Chatbot instance. @@ -199,6 +606,7 @@ class Chatbot: workflow_id: The workflow ID (maps to thread_id). config: Optional chatbot configuration for dynamic tool enablement. event_manager: Optional event manager for streaming (passed from route). + planner_model: Optional fast model for planner/routing (default: same as model). Returns: A configured Chatbot instance. @@ -210,9 +618,11 @@ class Chatbot: workflow_id=workflow_id, config=config, _event_manager=event_manager, + planner_model=planner_model, ) configured_tools = await instance._configure_tools() - instance.app = instance._build_app(memory, configured_tools) + instance._tools = configured_tools + instance.app = _get_or_build_cached_graph() return instance async def _configure_tools(self) -> List[Any]: @@ -281,6 +691,7 @@ class Chatbot: tools_sql = [t for t in [sql_tool, tavily_tool, streaming_tool] if t is not None] tools_tavily = [t for t in [tavily_tool, streaming_tool] if t is not None] llm_plain = self.model + llm_planner = self.planner_model if self.planner_model else self.model # SQL path uses structured prompts + parse/execute (no native tool calling) - fits /api/analyze llm_tavily = self.model.bind_tools(tools=tools_tavily) if tools_tavily else self.model @@ -348,7 +759,8 @@ class Chatbot: async def planner_node(state: ChatState) -> dict: """Planner: minimal prompt, no tools. Outputs SQL/TAVILY/BOTH/NONE. - Does NOT add planner message to chat - only sets state.plan for routing.""" + Does NOT add planner message to chat - only sets state.plan for routing. 
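+        Example: "wie viele Schrauben auf Lager?" -> SQL; "Infos zu Produkt XY" -> TAVILY.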
+ Uses llm_planner (fast model) when available for lower latency.""" human_msgs = [m for m in state.messages if isinstance(m, HumanMessage)] last_human = human_msgs[-1].content if human_msgs else "" window = [ @@ -357,7 +769,7 @@ class Chatbot: ] plan = "SQL" try: - response = await llm_plain.ainvoke(window) + response = await llm_planner.ainvoke(window) except ValueError as exc: if "No suitable model found" in str(exc): logger.warning(f"Planner model selection failed: {exc}") @@ -555,12 +967,12 @@ class Chatbot: return await _agent_common(state, system_content, llm_tavily, "agent_tavily") async def agent_answer_node(state: ChatState) -> dict: - """Agent with no tools. Uses intro + response_structure.""" + """Agent with no tools (plan NONE). Uses fast model for lower latency.""" resp_struct = _prompt_sections["response_structure"] or "" intro_answer = _prompt_sections["intro"] combined = f"{intro_answer}\n\n{resp_struct}" if resp_struct else intro_answer system_content = _truncate_system_prompt(combined, 6000, "") - return await _agent_common(state, system_content, llm_plain, "agent_answer") + return await _agent_common(state, system_content, llm_planner, "agent_answer") def should_continue_tavily(state: ChatState) -> str: last = state.messages[-1] @@ -722,16 +1134,29 @@ class Chatbot: Returns: The list of messages in the chat history. """ - # Set the right thread ID for memory config = {"configurable": {"thread_id": chat_id}} - - # Single-turn chat (non-streaming) - result = await self.app.ainvoke( - {"messages": [HumanMessage(content=message)]}, config=config + tools_by_name = {t.name: t for t in self._tools} + graph_ctx = ChatbotGraphContext( + model=self.model, + planner_model=self.planner_model or self.model, + tools=self._tools, + tools_by_name=tools_by_name, + sql_tool=tools_by_name.get("sqlite_query"), + tavily_tool=tools_by_name.get("tavily_search"), + streaming_tool=tools_by_name.get("send_streaming_message"), + prompt_sections=_split_system_prompt(self.system_prompt), + system_prompt=self.system_prompt, ) - - # Extract and return the messages from the result - return result["messages"] + ctx_token = _set_graph_context(graph_ctx) + cp_token = set_checkpointer(self.memory) + try: + result = await self.app.ainvoke( + {"messages": [HumanMessage(content=message)]}, config=config + ) + return result["messages"] + finally: + _reset_graph_context(ctx_token) + reset_checkpointer(cp_token) async def stream_events( self, *, message: str, chat_id: str = "default" @@ -757,6 +1182,21 @@ class Chatbot: """Return True if the event is from the root run (v2: empty parent_ids).""" return not ev.get("parent_ids") + # Build tool lookup for cached graph context + tools_by_name = {t.name: t for t in self._tools} + graph_ctx = ChatbotGraphContext( + model=self.model, + planner_model=self.planner_model or self.model, + tools=self._tools, + tools_by_name=tools_by_name, + sql_tool=tools_by_name.get("sqlite_query"), + tavily_tool=tools_by_name.get("tavily_search"), + streaming_tool=tools_by_name.get("send_streaming_message"), + prompt_sections=_split_system_prompt(self.system_prompt), + system_prompt=self.system_prompt, + ) + ctx_token = _set_graph_context(graph_ctx) + cp_token = set_checkpointer(self.memory) try: async for event in self.app.astream_events( {"messages": [HumanMessage(content=message)]}, @@ -767,6 +1207,25 @@ class Chatbot: ename = event.get("name") or "" edata = event.get("data") or {} + # Stream LLM tokens for ChatGPT-like incremental display + if etype in ("on_llm_stream", 
"on_chat_model_stream"): + ch = edata.get("chunk") + if ch is None: + continue + # Chunk can be string, AIMessageChunk (has .content), or dict + content = "" + if isinstance(ch, str): + content = ch + elif hasattr(ch, "content"): + content = ch.content or "" + if isinstance(content, list): + content = "".join(str(x) for x in content) + elif isinstance(ch, dict): + content = ch.get("content", "") or "" + if isinstance(content, str) and content: + yield {"type": "chunk", "content": content} + continue + # Stream human-readable progress via the special send_streaming_message tool # Match the legacy implementation exactly (line 267-272 in legacy/chatbot.py) if etype == "on_tool_start": @@ -833,3 +1292,6 @@ class Chatbot: # Emit a single error envelope and end the stream logger.error(f"Exception in stream_events: {exc}", exc_info=True) yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"} + finally: + _reset_graph_context(ctx_token) + reset_checkpointer(cp_token) diff --git a/modules/features/chatbot/interfaceFeatureChatbot.py b/modules/features/chatbot/interfaceFeatureChatbot.py index 15711e7e..741bf05f 100644 --- a/modules/features/chatbot/interfaceFeatureChatbot.py +++ b/modules/features/chatbot/interfaceFeatureChatbot.py @@ -397,8 +397,8 @@ class ChatObjects: dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET") dbPort = int(APP_CONFIG.get("DB_PORT", 5432)) - # Create database connector directly - self.db = DatabaseConnector( + from modules.connectors.connectorDbPostgre import _get_cached_connector + self.db = _get_cached_connector( dbHost=dbHost, dbDatabase=dbDatabase, dbUser=dbUser, @@ -769,6 +769,72 @@ class ChatObjects: """Backward-compat alias: workflowId maps to conversationId.""" return self.getConversation(workflowId) + def getWorkflowMinimal(self, workflowId: str) -> Optional[ChatbotConversation]: + """Lightweight fetch: conversation record only, no logs/messages. For resume path.""" + conversations = getRecordsetWithRBAC( + self.db, + ChatbotConversation, + self.currentUser, + recordFilter={"id": workflowId}, + mandateId=self.mandateId, + featureInstanceId=self.featureInstanceId, + featureCode="chatbot", + ) + if not conversations: + return None + conv = conversations[0] + max_steps = conv.get("maxSteps") + return ChatbotConversation( + id=conv["id"], + featureInstanceId=conv.get("featureInstanceId") or self.featureInstanceId or "", + name=conv.get("name"), + status=conv.get("status", "running"), + currentRound=conv.get("currentRound", 0) or 0, + lastActivity=conv.get("lastActivity", getUtcTimestamp()), + startedAt=conv.get("startedAt", getUtcTimestamp()), + workflowMode=ChatbotWorkflowModeEnum(conv.get("workflowMode", "Chatbot")), + maxSteps=max_steps if max_steps is not None else 10, + logs=[], + messages=[], + ) + + def getMessageCount(self, conversationId: str) -> int: + """Returns message count for a conversation (single query, no document fetch).""" + messages = getRecordsetWithRBAC( + self.db, + ChatbotMessage, + self.currentUser, + recordFilter={"conversationId": conversationId}, + mandateId=self.mandateId, + featureInstanceId=self.featureInstanceId, + featureCode="chatbot", + ) + return len(messages) if messages else 0 + + def updateWorkflowMinimal( + self, workflowId: str, workflowData: Dict[str, Any] + ) -> ChatbotConversation: + """Lightweight update: no logs/messages fetch. 
+        For resume, when the caller has a minimal workflow."""
+        if not self.checkRbacPermission(ChatbotConversation, "update", workflowId):
+            raise PermissionError(f"No permission to update conversation {workflowId}")
+        simpleFields, _ = self._separateObjectFields(ChatbotConversation, workflowData)
+        simpleFields["lastActivity"] = getUtcTimestamp()
+        updated = self.db.recordModify(ChatbotConversation, workflowId, simpleFields)
+        max_steps = updated.get("maxSteps")
+        return ChatbotConversation(
+            id=updated["id"],
+            featureInstanceId=updated.get("featureInstanceId") or self.featureInstanceId or "",
+            name=updated.get("name"),
+            status=updated.get("status", "running"),
+            currentRound=updated.get("currentRound", 0) or 0,
+            lastActivity=updated.get("lastActivity", getUtcTimestamp()),
+            startedAt=updated.get("startedAt", getUtcTimestamp()),
+            workflowMode=ChatbotWorkflowModeEnum(updated.get("workflowMode", "Chatbot")),
+            maxSteps=max_steps if max_steps is not None else 10,
+            logs=[],
+            messages=[],
+        )
+
     def createConversation(self, conversationData: Dict[str, Any]) -> ChatbotConversation:
         """Creates a new conversation if user has permission."""
         if not self.checkRbacPermission(ChatbotConversation, "create"):
@@ -825,9 +891,7 @@ class ChatObjects:
 
         updated = self.db.recordModify(ChatbotConversation, conversationId, simpleFields)
 
-        logs = self.getLogs(conversationId)
-        messages = self.getMessages(conversationId)
-
+        # Reuse logs/messages from conv — the update only touches simple fields, not related data
         return ChatbotConversation(
             id=updated["id"],
             featureInstanceId=updated.get("featureInstanceId") or conv.featureInstanceId or self.featureInstanceId or "",
@@ -838,8 +902,8 @@ class ChatObjects:
             startedAt=updated.get("startedAt", conv.startedAt),
             workflowMode=ChatbotWorkflowModeEnum(updated.get("workflowMode", conv.workflowMode.value)),
             maxSteps=updated.get("maxSteps") if updated.get("maxSteps") is not None else conv.maxSteps,
-            logs=logs,
-            messages=messages
+            logs=conv.logs,
+            messages=conv.messages
         )
 
     def updateWorkflow(self, workflowId: str, workflowData: Dict[str, Any]) -> ChatbotConversation:
@@ -955,11 +1019,13 @@ class ChatObjects:
         if pagination and pagination.sort:
            messageDicts = self._applySorting(messageDicts, pagination.sort)
 
-        # If no pagination requested, return all items
+        # If no pagination requested, return all items (batch-fetch documents to avoid N+1)
         if pagination is None:
+            msg_ids = [m["id"] for m in messageDicts]
+            docs_by_message = self.getDocumentsForMessages(msg_ids) if msg_ids else {}
             chat_messages = []
             for msg in messageDicts:
-                documents = self.getDocuments(msg["id"])
+                documents = docs_by_message.get(msg["id"], [])
                 chat_message = ChatbotMessage(
                     id=msg["id"],
                     conversationId=msg["conversationId"],
@@ -994,10 +1060,11 @@ class ChatObjects:
         startIdx = (pagination.page - 1) * pagination.pageSize
         endIdx = startIdx + pagination.pageSize
         pagedMessageDicts = messageDicts[startIdx:endIdx]
-
+        paged_msg_ids = [m["id"] for m in pagedMessageDicts]
+        docs_by_message = self.getDocumentsForMessages(paged_msg_ids) if paged_msg_ids else {}
         chat_messages = []
         for msg in pagedMessageDicts:
-            documents = self.getDocuments(msg["id"])
+            documents = docs_by_message.get(msg["id"], [])
             chat_message = ChatbotMessage(
                 id=msg["id"],
                 conversationId=msg["conversationId"],
@@ -1224,6 +1291,17 @@ class ChatObjects:
             logger.error(f"Error updating message {messageId}: {str(e)}", exc_info=True)
             raise ValueError(f"Error updating message {messageId}: {str(e)}")
 
+    def createStat(self, statData: Dict[str, Any]):
+        """Create a stat record. Kept for compatibility with ChatService; stats may not be persisted in the chatbot schema."""
+        from modules.datamodels.datamodelChat import ChatStat
+        stat = ChatStat(**statData)
+        try:
+            created = self.db.recordCreate(ChatStat, statData)
+            return ChatStat(**created)
+        except Exception as e:
+            logger.debug(f"createStat: not persisting (chatbot schema): {e}")
+            return stat
+
     def deleteMessage(self, conversationId: str, messageId: str) -> bool:
         """Deletes a conversation message and related data if user has access."""
         try:
@@ -1308,6 +1386,30 @@ class ChatObjects:
             logger.error(f"Error getting message documents: {str(e)}")
             return []
 
+    def getDocumentsForMessages(self, messageIds: List[str]) -> Dict[str, List[ChatbotDocument]]:
+        """Returns documents for multiple messages in one query, as {messageId: [doc, ...]}."""
+        if not messageIds:
+            return {}
+        try:
+            documents = getRecordsetWithRBAC(
+                self.db,
+                ChatbotDocument,
+                self.currentUser,
+                recordFilter={"messageId": messageIds},
+                mandateId=self.mandateId,
+                featureInstanceId=self.featureInstanceId,
+                featureCode="chatbot",
+            )
+            result: Dict[str, List[ChatbotDocument]] = {mid: [] for mid in messageIds}
+            for doc in documents:
+                mid = doc.get("messageId")
+                if mid in result:
+                    result[mid].append(ChatbotDocument(**doc))
+            return result
+        except Exception as e:
+            logger.error(f"Error getting documents for messages: {e}")
+            return {mid: [] for mid in messageIds}
+
     def createDocument(self, documentData: Dict[str, Any]) -> ChatbotDocument:
         """Creates a document for a message in normalized table."""
         try:
@@ -1451,11 +1553,14 @@ class ChatObjects:
         items = []
 
         messages = getRecordsetWithRBAC(self.db, ChatbotMessage, self.currentUser, recordFilter={"conversationId": conversationId}, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode="chatbot")
+        # Batch-fetch documents for all messages (avoids N+1)
+        message_ids = [m["id"] for m in messages if afterTimestamp is None or parseTimestamp(m.get("publishedAt"), default=getUtcTimestamp()) > afterTimestamp]
+        docs_by_message = self.getDocumentsForMessages(message_ids) if message_ids else {}
         for msg in messages:
             msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp())
             if afterTimestamp is not None and msgTimestamp <= afterTimestamp:
                 continue
-            documents = self.getDocuments(msg["id"])
+            documents = docs_by_message.get(msg["id"], [])
             chatMessage = ChatbotMessage(
                 id=msg["id"],
                 conversationId=msg["conversationId"],
diff --git a/modules/features/chatbot/mainChatbot.py b/modules/features/chatbot/mainChatbot.py
index 031056e9..766130b2 100644
--- a/modules/features/chatbot/mainChatbot.py
+++ b/modules/features/chatbot/mainChatbot.py
@@ -6,7 +6,7 @@ Handles feature initialization and RBAC catalog registration.
 """
 
 import logging
-from typing import Dict, List, Any
+from typing import Dict, List, Any, Optional
 
 logger = logging.getLogger(__name__)
@@ -48,6 +48,14 @@ RESOURCE_OBJECTS = [
     },
 ]
 
+# Service requirements for chatbot — resolved via the service center
+REQUIRED_SERVICES = [
+    {"serviceKey": "chat", "meta": {"usage": "File info, document handling"}},
+    {"serviceKey": "ai", "meta": {"usage": "AI calls, conversation name generation"}},
+    {"serviceKey": "billing", "meta": {"usage": "Usage tracking, balance checks"}},
+    {"serviceKey": "streaming", "meta": {"usage": "Event manager, ChatStreamingHelper"}},
+]
+
 # Template roles for this feature
 # Role names MUST follow convention: {featureCode}-{roleName}
 TEMPLATE_ROLES = [
@@ -170,6 +178,76 @@ def registerFeature(catalogService) -> bool:
         return False
 
 
+def getChatbotServices(
+    user,
+    mandateId: Optional[str] = None,
+    featureInstanceId: Optional[str] = None,
+    workflow=None,
+) -> "_ChatbotServiceHub":
+    """
+    Get a lightweight service hub for the chatbot (chat, ai, streaming) without loading
+    the full legacy Services hub. Avoids the ~90 ms spent in _loadFeatureInterfaces +
+    _loadFeatureServices by instantiating only the required services.
+    Uses interfaceFeatureChatbot (ChatObjects) for interfaceDbChat to avoid a
+    duplicate DB init - chatProcess reuses hub.interfaceDbChat.
+    """
+    from modules.services import PublicService
+    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
+    from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
+    from modules.services.serviceChat.mainServiceChat import ChatService
+    from modules.services.serviceAi.mainServiceAi import AiService
+    from modules.services.serviceStreaming.mainServiceStreaming import StreamingService
+
+    hub = _ChatbotServiceHub()
+    hub.user = user
+    hub.mandateId = mandateId
+    hub.featureInstanceId = featureInstanceId
+    hub.workflow = workflow
+    hub.featureCode = "chatbot"
+    hub.allowedProviders = None
+
+    hub.interfaceDbApp = getAppInterface(user, mandateId=mandateId)
+    # interfaceDbComponent: lazy-loaded on first access (saves ~100–300 ms when there are no file uploads)
+    hub._interfaceDbComponent_val = None
+    # Use ChatObjects (interfaceFeatureChatbot) - same as chatProcess, avoids an extra interfaceDbChat init
+    hub.interfaceDbChat = getChatbotInterface(
+        user, mandateId=mandateId, featureInstanceId=featureInstanceId
+    )
+
+    hub.chat = PublicService(ChatService(hub))
+    hub.ai = PublicService(AiService(hub), functionsOnly=False)
+    hub.streaming = PublicService(StreamingService(hub))
+
+    return hub
+
+
+class _ChatbotServiceHub:
+    """Lightweight hub with chat, ai and streaming for the chatbot; avoids the full Services init."""
+
+    user = None
+    mandateId = None
+    featureInstanceId = None
+    workflow = None
+    interfaceDbApp = None
+    _interfaceDbComponent_val = None
+    interfaceDbChat = None
+    chat = None
+    ai = None
+    streaming = None
+    featureCode = "chatbot"
+    allowedProviders = None
+
+    @property
+    def interfaceDbComponent(self):
+        """Lazy-load interfaceDbComponent on first access (saves ~100–300 ms when no files are involved)."""
+        if self._interfaceDbComponent_val is None:
+            from modules.interfaces.interfaceDbManagement import getInterface as getComponentInterface
+            self._interfaceDbComponent_val = getComponentInterface(
+                self.user, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId
+            )
+        return self._interfaceDbComponent_val
+
+
 def _syncTemplateRolesToDb() -> int:
     """
     Sync template roles and their AccessRules to the database.
diff --git a/modules/features/chatbot/routeFeatureChatbot.py b/modules/features/chatbot/routeFeatureChatbot.py
index d8231a07..b85b45bc 100644
--- a/modules/features/chatbot/routeFeatureChatbot.py
+++ b/modules/features/chatbot/routeFeatureChatbot.py
@@ -33,6 +33,17 @@ from modules.features.chatbot.interfaceFeatureChatbot import ChatbotConversation
 from modules.features.chatbot import chatProcess
 from modules.services.serviceStreaming import get_event_manager
 
+# Pre-warm AI connectors when this router loads (before the first request).
+# Ensures connectors are ready; avoids the 4–8 s delay on the first chatbot message.
+try:
+    import modules.aicore.aicoreModelRegistry  # noqa: F401
+    from modules.aicore.aicoreModelRegistry import modelRegistry
+    modelRegistry.ensureConnectorsRegistered()
+    modelRegistry.refreshModels(force=True)
+    logging.getLogger(__name__).info("Chatbot router: AI connectors pre-warmed")
+except Exception as e:
+    logging.getLogger(__name__).warning(f"Chatbot AI pre-warm failed: {e}")
+
 # Configure logger
 logger = logging.getLogger(__name__)
 
@@ -43,12 +54,14 @@ router = APIRouter(
     responses={404: {"description": "Not found"}}
 )
 
-def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None):
-    """Get chatbot interface with instance context."""
-    mandateId = str(context.mandateId) if context.mandateId else None
+def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None, mandateId: Optional[str] = None):
+    """Get chatbot interface with instance context.
+    Pass mandateId when available (e.g. from _validateInstanceAccess) to ensure a cache hit with getChatbotServices.
+    """
+    effective_mandate = mandateId if mandateId is not None else (str(context.mandateId) if context.mandateId else None)
     return interfaceDbChat.getInterface(
-        context.user, 
-        mandateId=mandateId,
+        context.user,
+        mandateId=effective_mandate,
         featureInstanceId=instanceId
     )
 
@@ -125,7 +138,7 @@ def get_chatbot_threads(
     mandateId = _validateInstanceAccess(instanceId, context)
 
     try:
-        interfaceDbChat = _getServiceChat(context, instanceId)
+        interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)
 
         if workflowId:
             workflow = interfaceDbChat.getWorkflow(workflowId)
@@ -266,12 +279,14 @@ async def stream_chatbot_start(
     async def event_stream():
         """Async generator for SSE events - pure event-driven streaming (no polling)."""
         try:
-            # Get interface for initial data and status checks
-            interfaceDbChat = _getServiceChat(context, instanceId)
-
-            # Get current workflow to check if resuming and get current round
-            current_workflow = interfaceDbChat.getWorkflow(workflow.id)
-            current_round = current_workflow.currentRound if current_workflow else None
+            # Yield a keepalive immediately so the client gets the 200 and a first byte fast
+            yield ": keepalive\n\n"
+
+            # Use the same mandateId as chatProcess so we hit the interface cache (avoids a duplicate DB init)
+            interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)
+
+            # Use the workflow from chatProcess (no refetch)
+            current_round = workflow.currentRound if workflow else None
             is_resuming = final_workflow_id is not None and current_round and current_round > 1
 
             # Send initial chat data (exact format as chatData endpoint) - only once at start
@@ -358,7 +373,7 @@ async def stream_chatbot_start(
                 event_type = event.get("type")
                 event_data = event.get("data", {})
 
-                # Emit chatdata events (messages, logs, stats, status) in exact chatData format
+                # Emit chatdata events (messages, logs, stats, status, chunk) in exact chatData format
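+                # Chunk events carry {"type": "chunk", "content": "<token text>"}
+                # (emitted by _bridge_chatbot_events in service.py) and are forwarded
+                # below to the client unchanged.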
                 if event_type == "chatdata" and event_data:
                     # Handle status events (transient UI feedback)
                     if event_data.get("type") == "status":
                         status_item = {
                             "type": "status",
                             "label": event_data.get("label", "")
                         }
                         yield f"data: {json.dumps(status_item)}\n\n"
+                    elif event_data.get("type") == "chunk":
+                        # Token chunks for ChatGPT-like streaming
+                        chunk_item = {
+                            "type": "chunk",
+                            "content": event_data.get("content", "")
+                        }
+                        yield f"data: {json.dumps(chunk_item)}\n\n"
                     else:
                         # Emit other chatdata items (messages, logs, stats) in exact chatData format
                         chatdata_item = event_data
@@ -444,7 +466,7 @@ async def stop_chatbot(
 
     try:
         # Get chatbot interface with instance context
-        interfaceDbChat = _getServiceChat(context, instanceId)
+        interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)
 
         # Get workflow to verify it exists and belongs to this instance
         workflow = interfaceDbChat.getWorkflow(workflowId)
@@ -519,8 +541,7 @@ def delete_chatbot(
     mandateId = _validateInstanceAccess(instanceId, context)
 
     try:
-        # Get service center
-        interfaceDbChat = _getServiceChat(context, instanceId)
+        interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)
 
         # Get workflow directly (interface already handles mandate filtering)
         workflow = interfaceDbChat.getWorkflow(workflowId)
diff --git a/modules/features/chatbot/service.py b/modules/features/chatbot/service.py
index c419b623..a1c2343c 100644
--- a/modules/features/chatbot/service.py
+++ b/modules/features/chatbot/service.py
@@ -19,7 +19,7 @@ from modules.datamodels.datamodelUam import User
 from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, ProcessingModeEnum
 from modules.datamodels.datamodelDocref import DocumentReferenceList, DocumentItemReference
 from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
-from modules.services import getInterface as getServices
+from modules.features.chatbot.mainChatbot import getChatbotServices
 from modules.features.chatbot.chatbot import Chatbot
 from modules.features.chatbot.bridges.ai import AICenterChatModel, clear_workflow_allowed_providers
 from modules.features.chatbot.bridges.memory import DatabaseCheckpointer
@@ -91,33 +91,28 @@ async def chatProcess(
         ChatbotConversation instance
     """
     try:
-        # Get services with mandate and feature instance context
-        services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
+        # Get the lightweight service hub (chat, ai, streaming only; avoids the ~90 ms legacy hub init)
+        services = getChatbotServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
         services.featureCode = 'chatbot'
-
-        # Load instance config and apply allowedProviders for AI calls (conversation name + main chat)
-        chatbot_config = await _load_chatbot_config(featureInstanceId)
-        if chatbot_config.model.allowedProviders:
-            services.allowedProviders = chatbot_config.model.allowedProviders
-            logger.info(f"Chatbot instance {featureInstanceId}: restricting to providers {chatbot_config.model.allowedProviders}")
-
-        from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
-        interfaceDbChat = getChatbotInterface(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
+
+        # Config load and model warm-up run later as background tasks; the stream returns ~2–3 s sooner
+        chatbot_config = None
+
+        # Reuse the hub's interfaceDbChat (ChatObjects) - avoids a duplicate DB init
+        interfaceDbChat = services.interfaceDbChat
 
         # Create or load workflow (event_manager passed from route)
         if workflowId:
-            workflow = interfaceDbChat.getWorkflow(workflowId)
+            # Lightweight resume: minimal fetch + minimal update (no logs/messages load)
+            workflow = interfaceDbChat.getWorkflowMinimal(workflowId)
             if not workflow:
                 raise ValueError(f"Workflow {workflowId} not found")
-
-            # Resume workflow: increment round number
             new_round = workflow.currentRound + 1
-            interfaceDbChat.updateWorkflow(workflowId, {
+            workflow = interfaceDbChat.updateWorkflowMinimal(workflowId, {
                 "status": "running",
                 "currentRound": new_round,
                 "lastActivity": getUtcTimestamp()
             })
-            workflow = interfaceDbChat.getWorkflow(workflowId)
             logger.info(f"Resumed workflow {workflowId}, round incremented to {new_round}")
 
             # Create event queue if it doesn't exist (for streaming)
@@ -166,10 +161,7 @@
 
             # Create event queue for new workflow (for streaming)
             event_manager.create_queue(workflow.id)
-
-            # Reload workflow to get current message count
-            workflow = interfaceDbChat.getWorkflow(workflow.id)
-
+
         # Process uploaded files and create ChatbotDocuments
         user_documents = []
         if userInput.listFileId and len(userInput.listFileId) > 0:
@@ -203,14 +195,19 @@
             except Exception as e:
                 logger.error(f"Error processing file ID {fileId}: {e}", exc_info=True)
 
-        # Store user message
+        # Store the user message (sequenceNr: on resume use the message count, otherwise len+1)
+        seq_nr = (
+            interfaceDbChat.getMessageCount(workflow.id) + 1
+            if workflowId
+            else len(workflow.messages) + 1
+        )
         userMessageData: Dict[str, Any] = {
             "id": f"msg_{uuid.uuid4()}",
             "conversationId": workflow.id,
             "message": userInput.prompt,
             "role": "user",
             "status": "first" if workflowId is None else "step",
-            "sequenceNr": len(workflow.messages) + 1,
+            "sequenceNr": seq_nr,
             "publishedAt": getUtcTimestamp(),
             "roundNumber": workflow.currentRound,
             "taskNumber": 0,
@@ -223,11 +220,12 @@
 
         userMessage = interfaceDbChat.createMessage(userMessageData, event_manager=None)
         logger.info(f"Stored user message: {userMessage.id} with {len(user_documents)} document(s)")
 
-        # Update workflow status
-        interfaceDbChat.updateWorkflow(workflow.id, {
-            "status": "running",
-            "lastActivity": getUtcTimestamp()
-        })
+        # Update workflow status (minimal update; the resume path already did this)
+        if not workflowId:
+            interfaceDbChat.updateWorkflowMinimal(workflow.id, {
+                "status": "running",
+                "lastActivity": getUtcTimestamp()
+            })
 
         # Pre-flight billing check before starting LangGraph (if mandateId present)
         if mandateId:
@@ -244,9 +242,6 @@
             config=chatbot_config,
             event_manager=event_manager
         ))
-
-        # Reload workflow to include new message
-        workflow = interfaceDbChat.getWorkflow(workflow.id)
 
         return workflow
 
     except Exception as e:
@@ -728,8 +723,7 @@
         if not document_id:
             try:
                 from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
-                from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
-                chatbotInterface = getChatbotInterface(services.user, mandateId=services.mandateId, featureInstanceId=services.featureInstanceId)
+                chatbotInterface = services.interfaceDbChat
                 documents = getRecordsetWithRBAC(
                     chatbotInterface.db,
                     ChatbotDocument,
@@ -928,6 +922,20 @@
                     step="status"
                 )
                 continue
+
+            # Handle token chunks for ChatGPT-like streaming (append to the message as it is generated)
+            if event_type == "chunk":
+                content = event.get("content", "")
+                if content:
+                    await event_manager.emit_event(
+                        context_id=workflow_id,
+                        event_type="chatdata",
+                        data={"type": "chunk", "content": content},
+                        event_category="chat",
+                        message="Token chunk",
+                        step="chunk"
+                    )
+                continue
 
             # Handle final response
             if event_type == "final":
@@ -1117,9 +1125,59 @@
         clear_workflow_allowed_providers(workflow_id)
 
 
+def _load_chatbot_config_sync(featureInstanceId: Optional[str]) -> ChatbotConfig:
+    """
+    Load chatbot configuration from FeatureInstance (database). Sync version for use in an executor thread.
+
+    Args:
+        featureInstanceId: Feature instance ID to load config from
+
+    Returns:
+        ChatbotConfig instance
+    """
+    if not featureInstanceId:
+        raise ValueError("featureInstanceId is required to load chatbot config")
+
+    from modules.interfaces.interfaceDbApp import getRootInterface
+    from modules.interfaces.interfaceFeatures import getFeatureInterface
+
+    rootInterface = getRootInterface()
+    featureInterface = getFeatureInterface(rootInterface.db)
+    instance = featureInterface.getFeatureInstance(featureInstanceId)
+
+    if not instance:
+        raise ValueError(f"FeatureInstance {featureInstanceId} not found")
+
+    logger.info(f"Loading chatbot config from FeatureInstance {featureInstanceId}")
+    return load_chatbot_config_from_instance(instance)
+
+
+def _warm_model_registry_cache_sync(
+    currentUser: "User",
+    mandateId: Optional[str] = None,
+    featureInstanceId: Optional[str] = None,
+) -> None:
+    """
+    Pre-warm the getAvailableModels cache so planner/agent model selection is a cache hit.
+    Uses mandateId/featureInstanceId for faster RBAC (fewer roles to load).
+    Sync helper; callers run it via asyncio.to_thread so it does not block the event loop.
+    """
+    from modules.interfaces.interfaceDbApp import getRootInterface
+    from modules.aicore.aicoreModelRegistry import modelRegistry
+
+    root = getRootInterface()
+    modelRegistry.getAvailableModels(
+        currentUser=currentUser,
+        rbacInstance=root.rbac,
+        mandateId=mandateId,
+        featureInstanceId=featureInstanceId,
+    )
+
+
 async def _load_chatbot_config(featureInstanceId: Optional[str]) -> ChatbotConfig:
     """
     Load chatbot configuration from FeatureInstance (database).
+    Runs in a thread pool to avoid blocking the event loop on DB I/O.
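+
+    A minimal call sketch (mirrors how _processChatbotMessageLangGraph below uses it):
+        config_task = asyncio.create_task(_load_chatbot_config(featureInstanceId))
+        config = await config_task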
 
     Args:
         featureInstanceId: Feature instance ID to load config from
@@ -1134,20 +1192,7 @@ async def _load_chatbot_config(featureInstanceId: Optional[str]) -> ChatbotConfi
         raise ValueError("featureInstanceId is required to load chatbot config")
 
     try:
-        # Import here to avoid circular imports
-        from modules.interfaces.interfaceDbApp import getRootInterface
-        from modules.interfaces.interfaceFeatures import getFeatureInterface
-
-        # Get feature instance from database
-        rootInterface = getRootInterface()
-        featureInterface = getFeatureInterface(rootInterface.db)
-        instance = featureInterface.getFeatureInstance(featureInstanceId)
-
-        if not instance:
-            raise ValueError(f"FeatureInstance {featureInstanceId} not found")
-
-        logger.info(f"Loading chatbot config from FeatureInstance {featureInstanceId}")
-        return load_chatbot_config_from_instance(instance)
+        return await asyncio.to_thread(_load_chatbot_config_sync, featureInstanceId)
     except ValueError:
         raise
     except Exception as e:
@@ -1250,9 +1295,25 @@ async def _processChatbotMessageLangGraph(
         event_manager: Event manager for streaming (passed from chatProcess)
     """
     try:
-        from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
-        interfaceDbChat = getChatbotInterface(currentUser, mandateId=services.mandateId, featureInstanceId=featureInstanceId)
-
+        # Start config load and model cache warm-up in parallel (planner/agent need a cache hit to avoid ~2–3 s per call)
+        config_task = asyncio.create_task(_load_chatbot_config(featureInstanceId)) if config is None else None
+        warm_task = asyncio.create_task(asyncio.to_thread(
+            _warm_model_registry_cache_sync, currentUser, services.mandateId, featureInstanceId
+        ))
+
+        # Emit the first status immediately so the stream feels responsive
+        await event_manager.emit_event(
+            context_id=workflowId,
+            event_type="chatdata",
+            data={"type": "status", "label": "Starte..."},
+            event_category="chat",
+            message="Status update",
+            step="status",
+        )
+
+        # Reuse interfaceDbChat from services (ChatObjects) - avoids a duplicate DB init
+        interfaceDbChat = services.interfaceDbChat
+
         # Reload workflow to get current messages
         workflow = interfaceDbChat.getWorkflow(workflowId)
         if not workflow:
@@ -1272,19 +1333,10 @@ async def _processChatbotMessageLangGraph(
             logger.info(f"Workflow {workflowId} was stopped, aborting processing")
             return
 
-        # Emit synthetic status for real-time UI feedback
-        await event_manager.emit_event(
-            context_id=workflowId,
-            event_type="chatdata",
-            data={"type": "status", "label": "Lade Konfiguration..."},
-            event_category="chat",
-            message="Status update",
-            step="status"
-        )
-
-        # Load configuration if not passed (e.g. when resuming)
-        if config is None:
-            config = await _load_chatbot_config(featureInstanceId)
+        # Await the config load and model cache warm-up (planner gets a cache hit, saving ~2–3 s)
+        if config_task is not None:
+            config = await config_task
+        await warm_task
 
         # Replace {{DATE}} placeholder in system prompt
         from datetime import datetime
@@ -1310,25 +1362,30 @@
         processing_mode=processing_mode,
         billing_callback=billing_callback,
         workflow_id=workflowId,
-        allowed_providers=allowed_providers
+        allowed_providers=allowed_providers,
+        mandate_id=services.mandateId,
+        feature_instance_id=featureInstanceId,
+    )
+    # Fast planner model (gpt-4o-mini etc.) for routing - saves ~1-2 s on the first response
+    planner_model = AICenterChatModel(
+        user=currentUser,
+        operation_type=operation_type,
+        processing_mode=processing_mode,
+        billing_callback=billing_callback,
+        workflow_id=workflowId,
+        allowed_providers=allowed_providers,
+        prefer_fast_model=True,
+        mandate_id=services.mandateId,
+        feature_instance_id=featureInstanceId,
     )
 
-    # Emit synthetic status for real-time UI feedback
-    await event_manager.emit_event(
-        context_id=workflowId,
-        event_type="chatdata",
-        data={"type": "status", "label": "Bereite Chat vor..."},
-        event_category="chat",
-        message="Status update",
-        step="status"
-    )
-
-    # Create memory/checkpointer (uses chatbot's own DB via interfaceFeatureChatbot)
+    # Create memory/checkpointer (reuse the interface to avoid an extra DB init)
     memory = DatabaseCheckpointer(
         user=currentUser,
         workflow_id=workflowId,
        mandateId=services.mandateId,
-        featureInstanceId=featureInstanceId
+        featureInstanceId=featureInstanceId,
+        interface=interfaceDbChat,
    )
 
     # Create chatbot instance with config for dynamic tool configuration
@@ -1338,7 +1395,8 @@
         system_prompt=system_prompt,
         workflow_id=workflowId,
         config=config,
-        event_manager=event_manager
+        event_manager=event_manager,
+        planner_model=planner_model,
     )
 
     # Emit synthetic status for real-time UI feedback
diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py
index c7e4f8bf..3283c577 100644
--- a/modules/interfaces/interfaceDbApp.py
+++ b/modules/interfaces/interfaceDbApp.py
@@ -15,7 +15,7 @@ from typing import Dict, Any, List, Optional, Union
 from passlib.context import CryptContext
 import uuid
 
-from modules.connectors.connectorDbPostgre import DatabaseConnector
+from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector
 from modules.shared.configuration import APP_CONFIG
 from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
 from modules.interfaces.interfaceBootstrap import initBootstrap
@@ -144,8 +144,7 @@ class AppObjects:
         dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
         dbPort = int(APP_CONFIG.get("DB_PORT", 5432))
 
-        # Create database connector directly
-        self.db = DatabaseConnector(
+        self.db = _get_cached_connector(
             dbHost=dbHost,
             dbDatabase=dbDatabase,
             dbUser=dbUser,
diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py
index 61e32886..e065bf6d 100644
--- a/modules/interfaces/interfaceDbManagement.py
+++ b/modules/interfaces/interfaceDbManagement.py
@@ -12,7 +12,7 @@ import hashlib
 import math
 from typing import Dict, Any, List, Optional, Union
 
-from modules.connectors.connectorDbPostgre import DatabaseConnector
+from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector
 from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
 from modules.security.rbac import RbacClass
 from modules.datamodels.datamodelRbac import AccessRuleContext
@@ -131,8 +131,7 @@ class ComponentObjects:
         dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
         dbPort = int(APP_CONFIG.get("DB_PORT", 5432))
 
-        # Create database connector directly
-        self.db = DatabaseConnector(
+        self.db = _get_cached_connector(
             dbHost=dbHost,
             dbDatabase=dbDatabase,
             dbUser=dbUser,
diff --git a/modules/security/rbac.py b/modules/security/rbac.py
index f660e419..f1d83252 100644
--- a/modules/security/rbac.py
+++ b/modules/security/rbac.py
@@ -11,7 +11,7 @@ Multi-Tenant Design:
 """
 
 import logging
-from typing import List, Optional, TYPE_CHECKING
+from typing import Dict, List, Optional, TYPE_CHECKING
 
 from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role
 from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
 from modules.datamodels.datamodelMembership import (
@@ -163,6 +163,52 @@ class RbacClass:
 
         return permissions
 
+    def checkResourceAccessBulk(
+        self,
+        user: User,
+        resourcePaths: List[str],
+        mandateId: Optional[str] = None,
+        featureInstanceId: Optional[str] = None
+    ) -> Dict[str, bool]:
+        """
+        Check view access for multiple RESOURCE paths in one pass.
+        Uses the same logic as getUserPermissions but batches the DB access.
+        Returns {path: has_view}.
+        """
+        result = {p: False for p in resourcePaths}
+        if not resourcePaths:
+            return result
+        # SysAdmin bypass
+        if hasattr(user, "isSysAdmin") and user.isSysAdmin:
+            return {p: True for p in resourcePaths}
+        roleIds = self._getRoleIdsForUser(user, mandateId, featureInstanceId)
+        if not roleIds:
+            return result
+        rulesWithPriority = self._getRulesForRoleIds(
+            roleIds, AccessRuleContext.RESOURCE, mandateId, featureInstanceId
+        )
+        for path in resourcePaths:
+            rolePermissions = {}
+            for priority, rule in rulesWithPriority:
+                if not self._ruleMatchesItem(rule, path):
+                    continue
+                roleId = rule.roleId
+                itemSpecificity = self._getItemSpecificity(rule, path)
+                if roleId not in rolePermissions:
+                    rolePermissions[roleId] = (priority, itemSpecificity, rule)
+                else:
+                    existingPriority, existingSpecificity, _ = rolePermissions[roleId]
+                    if priority > existingPriority or (
+                        priority == existingPriority and itemSpecificity > existingSpecificity
+                    ):
+                        rolePermissions[roleId] = (priority, itemSpecificity, rule)
+            highestPriority = max((p for p, _, _ in rolePermissions.values()), default=0)
+            for _, (priority, _, rule) in rolePermissions.items():
+                if priority >= highestPriority and rule.view:
+                    result[path] = True
+                    break
+        return result
+
     def _getRoleIdsForUser(
         self,
         user: User,