fix: performance improvements
- app.py: Pre-warm AI connectors at module load and in lifespan
- aicoreModelRegistry.py: Connector discovery cache, getAvailableModels cache, bulk RBAC, eager prewarm
- connectorDbPostgre.py: Connector cache, contextvars for userId, eviction (max 32)
- chatbot: Uses _get_cached_connector, Service center integration, BillingService exceptions instead of direct imports
- interfaceDbApp.py: Uses _get_cached_connector (see the call-site sketch below)
- interfaceDbManagement.py: Uses _get_cached_connector
- security/rbac.py: Adds checkResourceAccessBulk
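The interface modules switch from constructing DatabaseConnector(...) directly to the shared cache helper. A minimal sketch of that call-site change, using the _get_cached_connector signature from the connectorDbPostgre.py hunk below; the DB_HOST/DB_DATABASE/DB_USER config keys and currentUser are placeholders, not taken from this commit:

from modules.connectors.connectorDbPostgre import _get_cached_connector
from modules.shared.configuration import APP_CONFIG

# One connector is kept per (host, database, port); userId is applied per request
# via a contextvar instead of mutating the shared connector.
db = _get_cached_connector(
    dbHost=APP_CONFIG.get("DB_HOST"),              # placeholder config key
    dbDatabase=APP_CONFIG.get("DB_DATABASE"),      # placeholder config key
    dbUser=APP_CONFIG.get("DB_USER"),              # placeholder config key
    dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
    dbPort=int(APP_CONFIG.get("DB_PORT", 5432)),
    userId=currentUser.id,                         # current request's user (placeholder)
)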
This commit is contained in:
parent
42e79a724a
commit
6dc2afafb9
13 changed files with 1317 additions and 192 deletions
19
app.py
|
|
@ -280,12 +280,29 @@ initLogging()
|
|||
logger = logging.getLogger(__name__)
|
||||
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
|
||||
|
||||
# Pre-warm AI connectors on process load (before lifespan). Critical for chatbot latency.
|
||||
try:
|
||||
import modules.aicore.aicoreModelRegistry # noqa: F401
|
||||
logger.info("AI connectors pre-warm (app load) triggered")
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"AI pre-warm at app load failed: {e}")
|
||||
|
||||
# Define lifespan context manager for application startup/shutdown events
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Application is starting up")
|
||||
|
||||
# --- Pre-warm AI connectors FIRST (before any other startup work) ---
|
||||
# Avoids 4–8 s latency on first chatbot request; must run before first use.
|
||||
try:
|
||||
import modules.aicore.aicoreModelRegistry # noqa: F401 - triggers eager pre-warm
|
||||
from modules.aicore.aicoreModelRegistry import modelRegistry
|
||||
modelRegistry.ensureConnectorsRegistered()
|
||||
modelRegistry.refreshModels(force=True)
|
||||
logger.info("AI connectors and model registry pre-warmed")
|
||||
except Exception as e:
|
||||
logger.warning(f"AI pre-warm failed: {e}")
|
||||
|
||||
# Bootstrap database if needed (creates initial users, mandates, roles, etc.)
|
||||
# This must happen before getting root interface
|
||||
from modules.security.rootAccess import getRootDbAppConnector
|
||||
|
|
@ -333,7 +350,7 @@ async def lifespan(app: FastAPI):
|
|||
# Register audit log cleanup scheduler
|
||||
from modules.shared.auditLogger import registerAuditLogCleanupScheduler
|
||||
registerAuditLogCleanupScheduler()
|
||||
|
||||
|
||||
# Ensure billing settings and accounts exist for all mandates
|
||||
try:
|
||||
from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ Implements plugin-like architecture for connector discovery.
|
|||
import logging
|
||||
import importlib
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from modules.datamodels.datamodelAi import AiModel
|
||||
from .aicoreBase import BaseConnectorAi
|
||||
from modules.datamodels.datamodelUam import User
|
||||
|
|
@ -31,6 +32,9 @@ class ModelRegistry:
|
|||
self._lastRefresh: Optional[float] = None
|
||||
self._refreshInterval: float = 300.0 # 5 minutes
|
||||
self._connectorsInitialized: bool = False
|
||||
self._discoveredConnectorsCache: Optional[List[BaseConnectorAi]] = None # Avoid re-instantiating on every discoverConnectors() call
|
||||
self._getAvailableModelsCache: Dict[Tuple[str, int], Tuple[List[AiModel], float]] = {} # (user_id, rbac_id) -> (models, ts)
|
||||
self._getAvailableModelsCacheTtl: float = 30.0 # seconds
|
||||
|
||||
def registerConnector(self, connector: BaseConnectorAi):
|
||||
"""Register a connector and collect its models."""
|
||||
|
|
@ -68,34 +72,38 @@ class ModelRegistry:
|
|||
raise
|
||||
|
||||
def discoverConnectors(self) -> List[BaseConnectorAi]:
|
||||
"""Auto-discover connectors by scanning aicorePlugin*.py files."""
|
||||
"""Auto-discover connectors by scanning aicorePlugin*.py files. Cached after first call to avoid 4-8 s re-init on every use."""
|
||||
if self._discoveredConnectorsCache is not None:
|
||||
return self._discoveredConnectorsCache
|
||||
|
||||
connectors = []
|
||||
connectorDir = os.path.dirname(__file__)
|
||||
|
||||
|
||||
# Scan for connector files
|
||||
for filename in os.listdir(connectorDir):
|
||||
if filename.startswith('aicorePlugin') and filename.endswith('.py'):
|
||||
moduleName = filename[:-3] # Remove .py extension
|
||||
|
||||
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(f'modules.aicore.{moduleName}')
|
||||
|
||||
|
||||
# Find connector classes (classes that inherit from BaseConnectorAi)
|
||||
for attrName in dir(module):
|
||||
attr = getattr(module, attrName)
|
||||
if (isinstance(attr, type) and
|
||||
issubclass(attr, BaseConnectorAi) and
|
||||
if (isinstance(attr, type) and
|
||||
issubclass(attr, BaseConnectorAi) and
|
||||
attr != BaseConnectorAi):
|
||||
|
||||
|
||||
# Instantiate the connector
|
||||
connector = attr()
|
||||
connectors.append(connector)
|
||||
logger.info(f"Discovered connector: {connector.getConnectorType()}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to discover connector from {filename}: {e}")
|
||||
|
||||
|
||||
self._discoveredConnectorsCache = connectors
|
||||
return connectors
|
||||
|
||||
def ensureConnectorsRegistered(self):
|
||||
|
|
@ -175,24 +183,49 @@ class ModelRegistry:
|
|||
self.refreshModels()
|
||||
return [model for model in self._models.values() if model.priority == priority]
|
||||
|
||||
def getAvailableModels(self, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> List[AiModel]:
|
||||
def getAvailableModels(
|
||||
self,
|
||||
currentUser: Optional[User] = None,
|
||||
rbacInstance: Optional[RbacClass] = None,
|
||||
mandateId: Optional[str] = None,
|
||||
featureInstanceId: Optional[str] = None
|
||||
) -> List[AiModel]:
|
||||
"""Get only available models, optionally filtered by RBAC permissions.
|
||||
Results are cached per (user, rbac) for 30s to avoid repeated filtering on each LLM call.
|
||||
|
||||
Args:
|
||||
currentUser: Optional user object for RBAC filtering
|
||||
rbacInstance: Optional RBAC instance for permission checks
|
||||
mandateId: Optional mandate context for faster RBAC (loads fewer roles)
|
||||
featureInstanceId: Optional feature instance for RBAC context
|
||||
|
||||
Returns:
|
||||
List of available models (filtered by RBAC if user provided)
|
||||
"""
|
||||
self.refreshModels()
|
||||
cache_key = (currentUser.id if currentUser else "", id(rbacInstance) if rbacInstance else 0)
|
||||
now = time.time()
|
||||
if cache_key in self._getAvailableModelsCache:
|
||||
cached_models, cached_ts = self._getAvailableModelsCache[cache_key]
|
||||
if now - cached_ts < self._getAvailableModelsCacheTtl:
|
||||
logger.debug(f"getAvailableModels: cache hit for user={cache_key[0][:8] if cache_key[0] else 'anon'}...")
|
||||
return cached_models
|
||||
|
||||
allModels = list(self._models.values())
|
||||
availableModels = [model for model in allModels if model.isAvailable]
|
||||
|
||||
# Apply RBAC filtering if user and RBAC instance provided
|
||||
|
||||
# Apply RBAC filtering if user and RBAC instance provided (batch check for performance)
|
||||
if currentUser and rbacInstance:
|
||||
availableModels = self._filterModelsByRbac(availableModels, currentUser, rbacInstance)
|
||||
|
||||
availableModels = self._filterModelsByRbac(
|
||||
availableModels, currentUser, rbacInstance, mandateId, featureInstanceId
|
||||
)
|
||||
|
||||
self._getAvailableModelsCache[cache_key] = (availableModels, now)
|
||||
# Prune expired entries to avoid unbounded growth
|
||||
expired = [k for k, (_, ts) in self._getAvailableModelsCache.items() if now - ts >= self._getAvailableModelsCacheTtl]
|
||||
for k in expired:
|
||||
del self._getAvailableModelsCache[k]
|
||||
|
||||
unavailableCount = len(allModels) - len(availableModels)
|
||||
if unavailableCount > 0:
|
||||
unavailableModels = [m.name for m in allModels if not m.isAvailable]
|
||||
|
|
@ -200,32 +233,33 @@ class ModelRegistry:
|
|||
logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
|
||||
return availableModels
|
||||
|
||||
def _filterModelsByRbac(self, models: List[AiModel], currentUser: User, rbacInstance: RbacClass) -> List[AiModel]:
|
||||
"""Filter models based on RBAC permissions.
|
||||
|
||||
Args:
|
||||
models: List of models to filter
|
||||
currentUser: Current user object
|
||||
rbacInstance: RBAC instance for permission checks
|
||||
|
||||
Returns:
|
||||
Filtered list of models that user has access to
|
||||
"""
|
||||
def _filterModelsByRbac(
|
||||
self,
|
||||
models: List[AiModel],
|
||||
currentUser: User,
|
||||
rbacInstance: RbacClass,
|
||||
mandateId: Optional[str] = None,
|
||||
featureInstanceId: Optional[str] = None
|
||||
) -> List[AiModel]:
|
||||
"""Filter models based on RBAC permissions. Uses bulk check for performance."""
|
||||
paths = []
|
||||
model_paths = {} # model -> (connector_path, model_path)
|
||||
for model in models:
|
||||
connector_path = f"ai.model.{model.connectorType}"
|
||||
model_path = f"ai.model.{model.connectorType}.{model.displayName}"
|
||||
paths.extend([connector_path, model_path])
|
||||
model_paths[id(model)] = (connector_path, model_path)
|
||||
# Single bulk RBAC call instead of 2*N per-model calls
|
||||
access = rbacInstance.checkResourceAccessBulk(
|
||||
currentUser, list(dict.fromkeys(paths)), mandateId, featureInstanceId
|
||||
)
|
||||
filteredModels = []
|
||||
for model in models:
|
||||
# Check access at both connector level and model level
|
||||
connectorResourcePath = f"ai.model.{model.connectorType}"
|
||||
modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"
|
||||
|
||||
# User needs access to either connector (all models) or specific model
|
||||
hasConnectorAccess = checkResourceAccess(rbacInstance, currentUser, connectorResourcePath)
|
||||
hasModelAccess = checkResourceAccess(rbacInstance, currentUser, modelResourcePath)
|
||||
|
||||
if hasConnectorAccess or hasModelAccess:
|
||||
connector_path, model_path = model_paths[id(model)]
|
||||
if access.get(connector_path, False) or access.get(model_path, False):
|
||||
filteredModels.append(model)
|
||||
else:
|
||||
logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")
|
||||
|
||||
return filteredModels
|
||||
|
||||
def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]:
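checkResourceAccessBulk itself is added in security/rbac.py, which is not among the hunks shown here; a hedged sketch of the shape implied by the call site above, with _loadEffectiveRoles and _pathAllowed as hypothetical helpers standing in for whatever the real class uses:

from typing import Dict, List, Optional

def checkResourceAccessBulk(
    self,
    currentUser: "User",
    resourcePaths: List[str],
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None,
) -> Dict[str, bool]:
    """Sketch: check many resource paths in one pass and return {path: allowed}.

    The assumed gain over per-path checkResourceAccess is loading the user's
    effective roles once (scoped by mandateId/featureInstanceId) and testing
    every path against that single role set.
    """
    roles = self._loadEffectiveRoles(currentUser, mandateId, featureInstanceId)  # hypothetical helper
    return {path: self._pathAllowed(roles, path) for path in resourcePaths}      # hypothetical helper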
|
||||
|
|
@ -305,3 +339,17 @@ class ModelRegistry:
|
|||
|
||||
# Global registry instance
|
||||
modelRegistry = ModelRegistry()
|
||||
|
||||
# Eager pre-warm on first import: ensures connectors are ready in this process.
|
||||
# Critical for chatbot performance — avoids 4–8 s latency on first request.
|
||||
# Runs when this module is first imported (lifespan or first chatbot request).
|
||||
def _eager_prewarm() -> None:
|
||||
try:
|
||||
modelRegistry.ensureConnectorsRegistered()
|
||||
modelRegistry.refreshModels(force=True)
|
||||
logger.info("AI connectors and model registry pre-warmed (module load)")
|
||||
except Exception as e:
|
||||
logger.warning(f"AI eager pre-warm skipped: {e}")
|
||||
|
||||
|
||||
_eager_prewarm()
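A short note on why the double pre-warm (module load here, lifespan in app.py) is assumed to be safe: Python caches imported modules, so the later import is a no-op and the explicit registry calls only re-check already-initialized state.

import modules.aicore.aicoreModelRegistry as registry_module  # first import runs _eager_prewarm()

registry_module.modelRegistry.ensureConnectorsRegistered()    # expected no-op once initialized
registry_module.modelRegistry.refreshModels(force=True)       # only forces a model-list refresh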
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
import contextvars
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import logging
|
||||
|
|
@ -99,7 +100,56 @@ def _get_model_fields(model_class) -> Dict[str, str]:
|
|||
return fields
|
||||
|
||||
|
||||
# No caching needed with proper database
|
||||
# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
|
||||
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
|
||||
# contextvars to avoid races when concurrent requests share the same connector.
|
||||
_MAX_CACHED_CONNECTORS = 32
|
||||
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
|
||||
_connector_cache_order: List[tuple] = [] # FIFO order for eviction
|
||||
_connector_cache_lock = threading.Lock()
|
||||
_current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
|
||||
"db_connector_user_id", default=None
|
||||
)
|
||||
|
||||
|
||||
def _get_cached_connector(
|
||||
dbHost: str,
|
||||
dbDatabase: str,
|
||||
dbUser: str = None,
|
||||
dbPassword: str = None,
|
||||
dbPort: int = None,
|
||||
userId: str = None,
|
||||
) -> "DatabaseConnector":
|
||||
"""Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits.
|
||||
Uses contextvars for userId so concurrent requests sharing the same connector get correct _createdBy/_modifiedBy.
|
||||
"""
|
||||
port = int(dbPort) if dbPort is not None else 5432
|
||||
key = (dbHost, dbDatabase, port)
|
||||
with _connector_cache_lock:
|
||||
if key not in _connector_cache:
|
||||
# Evict oldest if at capacity
|
||||
while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
|
||||
oldest_key = _connector_cache_order.pop(0)
|
||||
if oldest_key in _connector_cache:
|
||||
try:
|
||||
_connector_cache[oldest_key].close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing evicted connector: {e}")
|
||||
del _connector_cache[oldest_key]
|
||||
_connector_cache[key] = DatabaseConnector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
dbPort=dbPort,
|
||||
userId=userId,
|
||||
)
|
||||
_connector_cache_order.append(key)
|
||||
conn = _connector_cache[key]
|
||||
# Set request-scoped userId via contextvar (avoids mutating shared connector)
|
||||
if userId is not None:
|
||||
_current_user_id.set(userId)
|
||||
return conn
|
||||
|
||||
|
||||
class DatabaseConnector:
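A short usage sketch of the cache semantics described above (host/database values are placeholders). In real traffic each request runs in its own context, so the contextvar holds that request's userId even though the connector object is shared:

conn_a = _get_cached_connector(dbHost="db", dbDatabase="app", dbPort=5432, userId="user-a")
conn_b = _get_cached_connector(dbHost="db", dbDatabase="app", dbPort=5432, userId="user-b")
# Same (host, database, port) key, so conn_a and conn_b are expected to be the same
# shared connector; _current_user_id carries the per-request author id used for
# _createdBy/_modifiedBy instead of a mutated connector attribute.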
|
||||
|
|
@ -645,24 +695,22 @@ class DatabaseConnector:
|
|||
if "id" in record and str(record["id"]) != recordId:
|
||||
raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}")
|
||||
|
||||
# Add metadata
|
||||
# Add metadata - use contextvar for request-scoped userId when sharing connector
|
||||
effective_user_id = _current_user_id.get()
|
||||
if effective_user_id is None:
|
||||
effective_user_id = self.userId
|
||||
currentTime = getUtcTimestamp()
|
||||
# Set _createdAt and _createdBy if this is a new record (record doesn't have _createdAt)
|
||||
if "_createdAt" not in record:
|
||||
record["_createdAt"] = currentTime
|
||||
# Only set _createdBy if userId is valid (not None or empty string)
|
||||
if self.userId:
|
||||
record["_createdBy"] = self.userId
|
||||
# No warning - empty userId is normal during bootstrap
|
||||
# Also ensure _createdBy is set even if _createdAt exists but _createdBy is missing/empty
|
||||
if effective_user_id:
|
||||
record["_createdBy"] = effective_user_id
|
||||
elif "_createdBy" not in record or not record.get("_createdBy"):
|
||||
if self.userId:
|
||||
record["_createdBy"] = self.userId
|
||||
# No warning - empty userId is normal during bootstrap
|
||||
# Always update modification metadata
|
||||
if effective_user_id:
|
||||
record["_createdBy"] = effective_user_id
|
||||
record["_modifiedAt"] = currentTime
|
||||
if self.userId:
|
||||
record["_modifiedBy"] = self.userId
|
||||
if effective_user_id:
|
||||
record["_modifiedBy"] = effective_user_id
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
self._save_record(cursor, table, recordId, record, model_class)
|
||||
|
|
@ -782,12 +830,13 @@ class DatabaseConnector:
|
|||
return False
|
||||
|
||||
def updateContext(self, userId: str) -> None:
|
||||
"""Updates the context of the database connector."""
|
||||
"""Updates the context of the database connector.
|
||||
Sets both instance userId and contextvar for request-scoped use when connector is shared.
|
||||
"""
|
||||
if userId is None:
|
||||
raise ValueError("userId must be provided")
|
||||
|
||||
self.userId = userId
|
||||
# No cache to clear - database handles data consistency
|
||||
_current_user_id.set(userId)
|
||||
|
||||
# Public API
|
||||
|
||||
|
|
|
|||
|
|
@ -38,9 +38,10 @@ from modules.datamodels.datamodelUam import User
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Workflow-level store for allowed_providers (survives LangGraph/bind_tools execution context
|
||||
# where instance attributes may be lost when model is wrapped or serialized)
|
||||
# Workflow-level store for allowed_providers and RBAC context (survives LangGraph/bind_tools
|
||||
# execution context where instance attributes may be lost when model is wrapped or serialized)
|
||||
_workflow_allowed_providers: Dict[str, List[str]] = {}
|
||||
_workflow_rbac_context: Dict[str, tuple] = {} # workflow_id -> (mandateId, featureInstanceId)
|
||||
|
||||
|
||||
def clear_workflow_allowed_providers(workflow_id: str) -> None:
|
||||
|
|
@ -62,11 +63,14 @@ class AICenterChatModel(BaseChatModel):
|
|||
billing_callback: Optional[Callable[[AiCallResponse], None]] = None,
|
||||
workflow_id: Optional[str] = None,
|
||||
allowed_providers: Optional[List[str]] = None,
|
||||
prefer_fast_model: bool = False,
|
||||
mandate_id: Optional[str] = None,
|
||||
feature_instance_id: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initialize the AI center chat model bridge.
|
||||
|
||||
|
||||
Args:
|
||||
user: Current user for RBAC and model selection
|
||||
operation_type: Operation type for model selection
|
||||
|
|
@ -74,6 +78,7 @@ class AICenterChatModel(BaseChatModel):
|
|||
billing_callback: Optional callback invoked after each _agenerate with AiCallResponse for billing
|
||||
workflow_id: Optional workflow/conversation ID for billing context
|
||||
allowed_providers: Optional list of allowed provider connector types (empty/None = all)
|
||||
prefer_fast_model: When True, strongly prefer faster models (e.g. gpt-4o-mini for planner)
|
||||
**kwargs: Additional arguments passed to BaseChatModel
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -85,9 +90,14 @@ class AICenterChatModel(BaseChatModel):
|
|||
object.__setattr__(self, "_billing_callback", billing_callback)
|
||||
object.__setattr__(self, "_workflow_id", workflow_id)
|
||||
object.__setattr__(self, "_allowed_providers", allowed_providers or [])
|
||||
object.__setattr__(self, "_prefer_fast_model", prefer_fast_model)
|
||||
object.__setattr__(self, "_mandate_id", mandate_id)
|
||||
object.__setattr__(self, "_feature_instance_id", feature_instance_id)
|
||||
# Store in workflow-level registry so it survives when instance attrs are lost (e.g. bind_tools)
|
||||
if workflow_id and allowed_providers:
|
||||
_workflow_allowed_providers[workflow_id] = list(allowed_providers)
|
||||
if workflow_id and (mandate_id is not None or feature_instance_id is not None):
|
||||
_workflow_rbac_context[workflow_id] = (mandate_id, feature_instance_id)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
|
|
@ -129,17 +139,25 @@ class AICenterChatModel(BaseChatModel):
|
|||
# Get available models with RBAC filtering
|
||||
# Use cached/singleton interfaces for better performance
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
|
||||
|
||||
workflow_id = getattr(self, "_workflow_id", None)
|
||||
rootInterface = getRootInterface()
|
||||
rbac_instance = rootInterface.rbac
|
||||
|
||||
|
||||
mandate_id = getattr(self, "_mandate_id", None)
|
||||
feature_instance_id = getattr(self, "_feature_instance_id", None)
|
||||
if workflow_id and (mandate_id is None and feature_instance_id is None):
|
||||
ctx = _workflow_rbac_context.get(workflow_id)
|
||||
if ctx:
|
||||
mandate_id, feature_instance_id = ctx
|
||||
available_models = modelRegistry.getAvailableModels(
|
||||
currentUser=self.user,
|
||||
rbacInstance=rbac_instance
|
||||
rbacInstance=rbac_instance,
|
||||
mandateId=mandate_id,
|
||||
featureInstanceId=feature_instance_id,
|
||||
)
|
||||
|
||||
# Allowed providers: instance attr or workflow store (lost in LangGraph/bind_tools context)
|
||||
workflow_id = getattr(self, '_workflow_id', None)
|
||||
allowed = (
|
||||
(_workflow_allowed_providers.get(workflow_id) if workflow_id else None)
|
||||
or getattr(self, '_allowed_providers', None)
|
||||
|
|
@ -155,7 +173,8 @@ class AICenterChatModel(BaseChatModel):
|
|||
options = AiCallOptions(
|
||||
operationType=self.operation_type,
|
||||
processingMode=self.processing_mode,
|
||||
allowedProviders=allowed if allowed else None
|
||||
allowedProviders=allowed if allowed else None,
|
||||
preferFastModel=getattr(self, "_prefer_fast_model", False),
|
||||
)
|
||||
|
||||
# Select model
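A hedged sketch of how a caller might build the fast planner model with the new constructor parameters; the operation_type value and the surrounding variables are placeholders, not taken from this commit:

planner_llm = AICenterChatModel(
    user=current_user,                        # placeholder user object
    operation_type="chat",                    # placeholder operation type
    workflow_id=workflow_id,
    prefer_fast_model=True,                   # bias model selection toward a faster model for routing
    mandate_id=mandate_id,                    # narrows RBAC role loading
    feature_instance_id=feature_instance_id,
)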
|
||||
|
|
@ -246,7 +265,97 @@ class AICenterChatModel(BaseChatModel):
|
|||
|
||||
# Run the async method synchronously
|
||||
return asyncio.run(self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs))
|
||||
|
||||
|
||||
async def _call_openai_streaming(
|
||||
self,
|
||||
ai_messages: List[dict],
|
||||
run_manager: Optional[Any],
|
||||
model_call: "AiModelCall",
|
||||
input_bytes: int,
|
||||
start_time: float,
|
||||
) -> "AiModelResponse":
|
||||
"""Call OpenAI/Ollama with stream=True, emit tokens via run_manager, return full response."""
|
||||
import httpx
|
||||
import json as _json
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
if self._selected_model.connectorType == "openai":
|
||||
api_url = getattr(self._selected_model, "apiUrl", None) or "https://api.openai.com/v1/chat/completions"
|
||||
api_key = APP_CONFIG.get("Connector_AiOpenai_API_SECRET")
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API key not configured")
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
ollama_model = self._selected_model.name
|
||||
else:
|
||||
base_url = getattr(self._selected_model, "apiUrl", "").replace("/api/analyze", "")
|
||||
api_url = f"{base_url.rstrip('/')}/v1/chat/completions"
|
||||
api_key = APP_CONFIG.get("Connector_AiPrivateLlm_API_SECRET")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["X-API-Key"] = api_key
|
||||
ollama_model = getattr(self._selected_model, "version", None) or self._selected_model.name
|
||||
|
||||
payload = {
|
||||
"model": ollama_model,
|
||||
"messages": ai_messages,
|
||||
"temperature": self._selected_model.temperature,
|
||||
"max_tokens": self._selected_model.maxTokens,
|
||||
"stream": True,
|
||||
}
|
||||
content_parts: List[str] = []
|
||||
async with httpx.AsyncClient(timeout=600.0) as client:
|
||||
async with client.stream("POST", api_url, headers=headers, json=payload) as resp:
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"OpenAI stream error: {resp.status_code} - {await resp.aread()}")
|
||||
buffer = ""
|
||||
async for chunk in resp.aiter_text():
|
||||
buffer += chunk
|
||||
while "\n" in buffer or "\r\n" in buffer:
|
||||
line, _, buffer = buffer.partition("\n")
|
||||
line = line.strip()
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:].strip()
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = _json.loads(data_str)
|
||||
choices = data.get("choices") or []
|
||||
if choices:
|
||||
delta = choices[0].get("delta") or {}
|
||||
token = delta.get("content") or ""
|
||||
if token and run_manager and hasattr(run_manager, "on_llm_new_token"):
|
||||
run_manager.on_llm_new_token(token)
|
||||
content_parts.append(token)
|
||||
except _json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
content = "".join(content_parts)
|
||||
processing_time = time.time() - start_time
|
||||
output_bytes = len(content.encode("utf-8"))
|
||||
price_chf = 0.0
|
||||
if getattr(self._selected_model, "calculatepriceCHF", None):
|
||||
try:
|
||||
price_chf = self._selected_model.calculatepriceCHF(processing_time, input_bytes, output_bytes)
|
||||
except Exception:
|
||||
pass
|
||||
billing_callback = getattr(self, "_billing_callback", None)
|
||||
if billing_callback:
|
||||
try:
|
||||
billing_callback(AiCallResponse(
|
||||
content=content,
|
||||
modelName=self._selected_model.name,
|
||||
provider=self._selected_model.connectorType or "unknown",
|
||||
priceCHF=price_chf,
|
||||
processingTime=processing_time,
|
||||
bytesSent=input_bytes,
|
||||
bytesReceived=output_bytes,
|
||||
errorCount=0,
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Billing callback error: {e}")
|
||||
|
||||
return AiModelResponse(content=content, success=True, modelId=self._selected_model.name, metadata={})
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
|
|
@ -484,6 +593,11 @@ class AICenterChatModel(BaseChatModel):
|
|||
"tool_calls": tool_calls,
|
||||
},
|
||||
)
|
||||
elif not tools and self._selected_model.connectorType in ("openai", "privatellm"):
|
||||
# Streaming path for OpenAI/Ollama without tools (ChatGPT-like token streaming)
|
||||
response = await self._call_openai_streaming(
|
||||
ai_messages, run_manager, model_call, input_bytes, start_time
|
||||
)
|
||||
else:
|
||||
# No tools or not OpenAI - use connector normally
|
||||
if not self._selected_model.functionCall:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Custom LangGraph checkpointer using existing database interface.
|
|||
Maps LangGraph state to existing message storage format.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, NamedTuple
|
||||
|
|
@ -47,19 +48,30 @@ class DatabaseCheckpointer(BaseCheckpointSaver):
|
|||
Maps LangGraph thread_id to conversation.id; stores messages via interface (workflowId maps to conversationId).
|
||||
"""
|
||||
|
||||
def __init__(self, user: User, workflow_id: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
user: User,
|
||||
workflow_id: str,
|
||||
mandateId: Optional[str] = None,
|
||||
featureInstanceId: Optional[str] = None,
|
||||
*,
|
||||
interface=None,
|
||||
):
|
||||
"""
|
||||
Initialize the database checkpointer.
|
||||
|
||||
|
||||
Args:
|
||||
user: Current user for database access
|
||||
workflow_id: Workflow ID (maps to LangGraph thread_id)
|
||||
mandateId: Mandate ID for proper data isolation
|
||||
featureInstanceId: Feature instance ID for proper data isolation
|
||||
interface: Optional pre-created chatbot interface (avoids extra getInterface + DB init)
|
||||
"""
|
||||
self.user = user
|
||||
self.workflow_id = workflow_id
|
||||
self.interface = getChatbotInterface(user, mandateId=mandateId, featureInstanceId=featureInstanceId)
|
||||
self.interface = interface if interface is not None else getChatbotInterface(
|
||||
user, mandateId=mandateId, featureInstanceId=featureInstanceId
|
||||
)
|
||||
|
||||
def _convert_langchain_to_db_message(
|
||||
self,
|
||||
|
|
@ -445,3 +457,120 @@ class DatabaseCheckpointer(BaseCheckpointSaver):
|
|||
# Not implemented - using aput() instead
|
||||
# This method is called by LangGraph but we handle writes through aput()
|
||||
pass
|
||||
|
||||
|
||||
# ContextVar for per-request checkpointer (used by CheckpointerResolver for graph caching)
|
||||
_current_checkpointer: contextvars.ContextVar[Optional[BaseCheckpointSaver]] = contextvars.ContextVar(
|
||||
"chatbot_current_checkpointer", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_checkpointer(checkpointer: BaseCheckpointSaver) -> contextvars.Token:
|
||||
"""Set the current request's checkpointer. Returns token to reset later."""
|
||||
return _current_checkpointer.set(checkpointer)
|
||||
|
||||
|
||||
def reset_checkpointer(token: contextvars.Token) -> None:
|
||||
"""Reset checkpointer to prior value. Safe when called from a different async context."""
|
||||
try:
|
||||
_current_checkpointer.reset(token)
|
||||
except ValueError:
|
||||
# Token was created in a different context (e.g. after yield, generator cleanup)
|
||||
pass
|
||||
|
||||
|
||||
class CheckpointerResolver(BaseCheckpointSaver):
|
||||
"""
|
||||
Delegating checkpointer that reads the real checkpointer from context.
|
||||
Used for graph caching: the compiled graph uses this resolver; at invoke time
|
||||
the per-request checkpointer is set via set_checkpointer().
|
||||
"""
|
||||
|
||||
def _get_checkpointer(self) -> BaseCheckpointSaver:
|
||||
cp = _current_checkpointer.get()
|
||||
if cp is None:
|
||||
raise RuntimeError(
|
||||
"CheckpointerResolver: no checkpointer in context. "
|
||||
"Call set_checkpointer() before invoking the cached graph."
|
||||
)
|
||||
return cp
|
||||
|
||||
def put(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: Dict[str, int],
|
||||
) -> None:
|
||||
self._get_checkpointer().put(config, checkpoint, metadata, new_versions)
|
||||
|
||||
def get(self, config: Dict[str, Any]) -> Optional[Checkpoint]:
|
||||
return self._get_checkpointer().get(config)
|
||||
|
||||
def list(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
before: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[Checkpoint]:
|
||||
return self._get_checkpointer().list(config, filter, before, limit)
|
||||
|
||||
def put_writes(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
writes: List[Tuple[str, Any]],
|
||||
task_id: str,
|
||||
) -> None:
|
||||
self._get_checkpointer().put_writes(config, writes, task_id)
|
||||
|
||||
async def aget_tuple(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]:
|
||||
inner = self._get_checkpointer()
|
||||
if hasattr(inner, "aget_tuple"):
|
||||
return await inner.aget_tuple(config)
|
||||
checkpoint = inner.get(config)
|
||||
if checkpoint:
|
||||
metadata: CheckpointMetadata = {"step": 0}
|
||||
return CheckpointTuple(
|
||||
config=config,
|
||||
checkpoint=checkpoint,
|
||||
metadata=metadata,
|
||||
parent_config=None,
|
||||
pending_writes=None,
|
||||
)
|
||||
return None
|
||||
|
||||
async def aput(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: Dict[str, int],
|
||||
) -> None:
|
||||
inner = self._get_checkpointer()
|
||||
if hasattr(inner, "aput"):
|
||||
await inner.aput(config, checkpoint, metadata, new_versions)
|
||||
else:
|
||||
inner.put(config, checkpoint, metadata, new_versions)
|
||||
|
||||
async def alist(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
before: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[Checkpoint]:
|
||||
inner = self._get_checkpointer()
|
||||
if hasattr(inner, "alist"):
|
||||
return await inner.alist(config, filter, before, limit)
|
||||
return inner.list(config, filter, before, limit)
|
||||
|
||||
async def aput_writes(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
writes: List[Tuple[str, Any]],
|
||||
task_id: str,
|
||||
) -> None:
|
||||
inner = self._get_checkpointer()
|
||||
if hasattr(inner, "aput_writes"):
|
||||
await inner.aput_writes(config, writes, task_id)
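The resolver is meant to be wrapped by set_checkpointer()/reset_checkpointer() around each invocation of the cached graph, as the Chatbot.chat hunk further below does. A condensed sketch of that pattern, assuming it runs inside an async function with user, workflow_id, mandate_id, text and cached_graph in scope:

token = set_checkpointer(DatabaseCheckpointer(user, workflow_id, mandateId=mandate_id))
try:
    result = await cached_graph.ainvoke(
        {"messages": [HumanMessage(content=text)]},
        config={"configurable": {"thread_id": workflow_id}},
    )
finally:
    reset_checkpointer(token)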
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@
|
|||
# All rights reserved.
|
||||
"""Chatbot domain logic."""
|
||||
|
||||
import contextvars
|
||||
import re
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Annotated, AsyncIterator, Any, List, Optional, TYPE_CHECKING
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -21,7 +23,12 @@ from langgraph.graph import StateGraph, START, END
|
|||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from modules.features.chatbot.bridges.ai import AICenterChatModel
|
||||
from modules.features.chatbot.bridges.memory import DatabaseCheckpointer
|
||||
from modules.features.chatbot.bridges.memory import (
|
||||
CheckpointerResolver,
|
||||
DatabaseCheckpointer,
|
||||
set_checkpointer,
|
||||
reset_checkpointer,
|
||||
)
|
||||
from modules.features.chatbot.bridges.tools import (
|
||||
create_sql_query_tool,
|
||||
create_tavily_search_tool,
|
||||
|
|
@ -168,13 +175,412 @@ class ChatState(BaseModel):
|
|||
plan: Optional[str] = None # Planner routing: "SQL", "TAVILY", "BOTH", "NONE"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatbotGraphContext:
|
||||
"""Per-request context for cached graph execution. Nodes read model/tools from here."""
|
||||
|
||||
model: AICenterChatModel
|
||||
planner_model: AICenterChatModel
|
||||
tools: List[Any]
|
||||
tools_by_name: dict
|
||||
sql_tool: Any
|
||||
tavily_tool: Any
|
||||
streaming_tool: Any
|
||||
prompt_sections: dict
|
||||
system_prompt: str
|
||||
|
||||
|
||||
_graph_context: contextvars.ContextVar[Optional[ChatbotGraphContext]] = contextvars.ContextVar(
|
||||
"chatbot_graph_context", default=None
|
||||
)
|
||||
|
||||
|
||||
def _get_graph_context() -> ChatbotGraphContext:
|
||||
ctx = _graph_context.get()
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"ChatbotGraphContext not set. Ensure graph context is set before invoking cached graph."
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
def _set_graph_context(ctx: ChatbotGraphContext) -> contextvars.Token:
|
||||
return _graph_context.set(ctx)
|
||||
|
||||
|
||||
def _reset_graph_context(token: contextvars.Token) -> None:
|
||||
"""Reset graph context. Safe when called from a different async context (e.g. generator cleanup)."""
|
||||
try:
|
||||
_graph_context.reset(token)
|
||||
except ValueError:
|
||||
# Token was created in a different context (e.g. after yield, generator cleanup)
|
||||
pass
|
||||
|
||||
|
||||
# Cached compiled graph; lock for thread-safe cache access
|
||||
_compiled_graph_cache: Optional[CompiledStateGraph] = None
|
||||
_compiled_graph_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_or_build_cached_graph() -> CompiledStateGraph:
|
||||
"""Return cached compiled graph or build and cache it. Thread-safe."""
|
||||
global _compiled_graph_cache
|
||||
with _compiled_graph_lock:
|
||||
if _compiled_graph_cache is not None:
|
||||
return _compiled_graph_cache
|
||||
_compiled_graph_cache = _build_cached_graph()
|
||||
logger.info("Chatbot: compiled graph cached for reuse")
|
||||
return _compiled_graph_cache
|
||||
|
||||
|
||||
def _build_cached_graph() -> CompiledStateGraph:
|
||||
"""Build the chatbot graph with context-resolved nodes and CheckpointerResolver."""
|
||||
checkpointer = CheckpointerResolver()
|
||||
|
||||
PLANNER_SYSTEM = (
|
||||
"Du bist ein Assistent. Antworte NUR mit einem Wort: SQL, TAVILY, BOTH oder NONE.\n"
|
||||
"SQL = Fragen zu Lager, Bestand, Artikel, Preisen, wie viele, Anzahl (Datenbankabfrage).\n"
|
||||
"TAVILY = Internetsuche, Produktinfos außerhalb der DB, Markttrends.\n"
|
||||
"BOTH = beides nötig. NONE = nur Begrüßung oder Danksagung, keine Daten nötig.\n"
|
||||
"Beispiele: 'wie viele X auf Lager' -> SQL, 'Infos zu Produkt Y' -> TAVILY."
|
||||
)
|
||||
SCHEMA_TRUNCATION_SUFFIX = (
|
||||
"\n\n[... Schema gekürzt. Wichtige Tabellen: Artikel, Lagerplatz_Artikel, Einkaufspreis, Lagerplatz. "
|
||||
"Artikel-Spalte: a.\"Artikelbezeichnung\". "
|
||||
"JOIN: Artikel a, Lagerplatz_Artikel l ON a.I_ID = l.R_ARTIKEL, Lagerplatz lp ON l.R_LAGERPLATZ = lp.I_ID.]"
|
||||
)
|
||||
SQL_PLAN_SUFFIX = (
|
||||
"\n\n--- AUSGABEFORMAT (PFLICHT) ---\n"
|
||||
"Antworte NUR mit einer SQL SELECT-Abfrage in diesem Format:\n"
|
||||
"```sql\nDEINE_SQL_QUERY\n```\n"
|
||||
"KRITISCH bei 'wie viele X auf Lager': Liefere ARTIKELZEILEN (Artikelnummer, Artikelbezeichnung, Bestand) "
|
||||
"mit LIMIT 20, NICHT nur SELECT COUNT(*). Ohne Detailzeilen kann keine Tabelle angezeigt werden. "
|
||||
"Gesamtanzahl optional via Unterabfrage im SELECT."
|
||||
)
|
||||
FORMULATE_TASK = (
|
||||
"\n\n--- AKTUELLE AUFGABE ---\n"
|
||||
"Du erhältst eine Benutzerfrage und die exakten Datenbankergebnisse. "
|
||||
"KRITISCH: Nutze NUR die gelieferten Daten. Erfinde NIEMALS Daten (keine LED-A01, LED Rot, etc.). "
|
||||
"Wenn die Ergebnisse NUR eine Zahl enthalten (z.B. '1. COUNT(*): 806'): Reportiere NUR diese Zahl, KEINE erfundene Tabelle. "
|
||||
"Eine Tabelle darf NUR erstellt werden, wenn echte Zeilen '1. Spalte: Wert, ...' in den Daten stehen. "
|
||||
"Beachte die obige ANTWORTSTRUKTUR."
|
||||
)
|
||||
bytes_per_token = 3
|
||||
reserved_tokens = 3000
|
||||
_SQL_KEYWORDS = (
|
||||
"lager", "bestand", "artikel", "wie viele", "anzahl", "preis",
|
||||
"lieferant", "lieferanten", "bestellen", "verfügbar", "inventar"
|
||||
)
|
||||
|
||||
def _get_context_length(ctx: ChatbotGraphContext) -> int:
|
||||
if hasattr(ctx.model, "_selected_model") and ctx.model._selected_model:
|
||||
return getattr(ctx.model._selected_model, "contextLength", 128000)
|
||||
return 128000
|
||||
|
||||
def _truncate_system_prompt(full_prompt: str, max_chars: int, suffix: str = "") -> str:
|
||||
if len(full_prompt) <= max_chars:
|
||||
return full_prompt
|
||||
return full_prompt[: max_chars - len(suffix)] + suffix
|
||||
|
||||
async def planner_node(state: ChatState) -> dict:
|
||||
ctx = _get_graph_context()
|
||||
human_msgs = [m for m in state.messages if isinstance(m, HumanMessage)]
|
||||
last_human = human_msgs[-1].content if human_msgs else ""
|
||||
window = [SystemMessage(content=PLANNER_SYSTEM), HumanMessage(content=last_human)]
|
||||
plan = "SQL"
|
||||
try:
|
||||
response = await ctx.planner_model.ainvoke(window)
|
||||
except ValueError as exc:
|
||||
if "No suitable model found" in str(exc):
|
||||
logger.warning(f"Planner model selection failed: {exc}")
|
||||
return {"plan": plan}
|
||||
raise
|
||||
content = (response.content or "").strip().upper()
|
||||
for keyword in ("SQL", "TAVILY", "BOTH", "NONE"):
|
||||
if keyword in content:
|
||||
plan = keyword
|
||||
break
|
||||
return {"plan": plan}
|
||||
|
||||
def route_by_plan(state: ChatState) -> str:
|
||||
ctx = _get_graph_context()
|
||||
plan = (state.plan or "SQL").upper()
|
||||
if plan == "NONE" and ctx.sql_tool:
|
||||
last_user = ""
|
||||
for m in reversed(state.messages):
|
||||
if isinstance(m, HumanMessage):
|
||||
last_user = (m.content or "").lower()
|
||||
break
|
||||
if any(kw in last_user for kw in _SQL_KEYWORDS):
|
||||
logger.info("Planner returned NONE but user asked inventory question - routing to SQL")
|
||||
plan = "SQL"
|
||||
if plan in ("SQL", "BOTH") and ctx.sql_tool:
|
||||
return "agent_sql_plan"
|
||||
if plan == "TAVILY" and ctx.tavily_tool:
|
||||
return "agent_tavily"
|
||||
return "agent_answer"
|
||||
|
||||
def select_window(ctx: ChatbotGraphContext, msgs: List[BaseMessage], max_tokens_override: Optional[int] = None) -> List[BaseMessage]:
|
||||
def approx_counter(items: List[BaseMessage]) -> int:
|
||||
return sum(len(getattr(m, "content", "") or "") for m in items)
|
||||
max_tokens = max_tokens_override or _get_context_length(ctx)
|
||||
return trim_messages(
|
||||
msgs,
|
||||
strategy="last",
|
||||
token_counter=approx_counter,
|
||||
max_tokens=int(max_tokens * 0.8),
|
||||
start_on="human",
|
||||
end_on=("human", "tool"),
|
||||
include_system=True,
|
||||
)
|
||||
|
||||
async def _agent_common(state: ChatState, system_content: str, llm: Any, node_name: str) -> dict:
|
||||
ctx = _get_graph_context()
|
||||
msgs = select_window(ctx, state.messages)
|
||||
if not msgs or not isinstance(msgs[0], SystemMessage):
|
||||
window = [SystemMessage(content=system_content)] + msgs
|
||||
else:
|
||||
window = [SystemMessage(content=system_content)] + [m for m in msgs if not isinstance(m, SystemMessage)]
|
||||
try:
|
||||
response = await llm.ainvoke(window)
|
||||
except ValueError as exc:
|
||||
if "No suitable model found" in str(exc):
|
||||
logger.warning(f"{node_name} model selection failed: {exc}")
|
||||
response = AIMessage(
|
||||
content="Es tut mir leid, derzeit steht kein passendes KI-Modell für diese Anfrage zur Verfügung. "
|
||||
"Bitte versuchen Sie es später erneut oder wenden Sie sich an den Administrator."
|
||||
)
|
||||
else:
|
||||
raise
|
||||
return {"messages": [response]}
|
||||
|
||||
def _parse_sql_from_content(content: str) -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
match = re.search(r"```(?:sql)?\s*([\s\S]*?)```", content)
|
||||
if match:
|
||||
sql = match.group(1).strip()
|
||||
if sql and sql.upper().strip().startswith("SELECT"):
|
||||
return sql
|
||||
for line in content.split("\n"):
|
||||
line = line.strip()
|
||||
if line.upper().startswith("SELECT"):
|
||||
return line
|
||||
return None
|
||||
|
||||
def _sanitize_sql_typos(sql: str) -> str:
|
||||
if not sql:
|
||||
return sql
|
||||
sql = re.sub(r"WHEN([A-Za-z_][A-Za-z0-9_.\"]*)", r"WHEN \1", sql, flags=re.IGNORECASE)
|
||||
sql = re.sub(r"\bLAGerplatz_Artikel\b", "Lagerplatz_Artikel", sql)
|
||||
sql = re.sub(r"\bLAGerplatz\b", "Lagerplatz", sql)
|
||||
sql = sql.replace('"Einkaufspreis_neu"', '"Einkaufspreis"')
|
||||
sql = sql.replace("Einkaufspreis_neu.", "Einkaufspreis.")
|
||||
sql = re.sub(r'"Einkaufspreis"\."ARTIKEL"', '"Einkaufspreis"."m_Artikel"', sql)
|
||||
return sql
|
||||
|
||||
async def agent_sql_plan_node(state: ChatState) -> dict:
|
||||
ctx = _get_graph_context()
|
||||
ctx_len = _get_context_length(ctx)
|
||||
max_system_chars = max(1000, int(ctx_len * 0.8 - reserved_tokens) * bytes_per_token) - len(SQL_PLAN_SUFFIX)
|
||||
schema_part = ctx.prompt_sections.get("schema") or ctx.prompt_sections.get("intro", "")
|
||||
intro_part = (ctx.prompt_sections.get("intro", "") or "")[:400]
|
||||
combined = f"{intro_part}\n\n{schema_part}" if intro_part else schema_part
|
||||
system_content = _truncate_system_prompt(combined, max_system_chars, SCHEMA_TRUNCATION_SUFFIX) + SQL_PLAN_SUFFIX
|
||||
llm = ctx.model
|
||||
return await _agent_common(state, system_content, llm, "agent_sql_plan")
|
||||
|
||||
async def parse_execute_sql_node(state: ChatState) -> dict:
|
||||
ctx = _get_graph_context()
|
||||
sql_t = ctx.sql_tool
|
||||
last_msg = state.messages[-1] if state.messages else None
|
||||
if not isinstance(last_msg, AIMessage):
|
||||
return {"messages": [ToolMessage(content="Fehler: Keine AI-Antwort zum Parsen.", tool_call_id="parse_0", name="sqlite_query")]}
|
||||
sql = _parse_sql_from_content(last_msg.content or "")
|
||||
if not sql or not sql_t:
|
||||
return {"messages": [ToolMessage(content="Konnte keine SQL-Abfrage aus der Antwort extrahieren.", tool_call_id="parse_0", name="sqlite_query")]}
|
||||
sql = _sanitize_sql_typos(sql)
|
||||
try:
|
||||
result = await sql_t.ainvoke({"query": sql})
|
||||
except Exception as e:
|
||||
logger.error(f"SQL execution failed: {e}")
|
||||
result = f"Fehler bei der Ausführung: {e}"
|
||||
return {"messages": [ToolMessage(content=str(result), tool_call_id="parse_0", name="sqlite_query")]}
|
||||
|
||||
async def agent_formulate_node(state: ChatState) -> dict:
|
||||
ctx = _get_graph_context()
|
||||
human_content = ""
|
||||
tool_content = ""
|
||||
for m in state.messages:
|
||||
if isinstance(m, HumanMessage):
|
||||
human_content = m.content or ""
|
||||
if isinstance(m, ToolMessage) and getattr(m, "name", "") == "sqlite_query":
|
||||
tool_content = m.content or ""
|
||||
if not tool_content or not tool_content.strip():
|
||||
logger.warning("agent_formulate: no tool_content (sqlite_query) in state.messages")
|
||||
return {"messages": [AIMessage(content="Die Datenbankabfrage konnte keine Ergebnisse liefern. Bitte versuchen Sie es mit einer anderen Formulierung.")]}
|
||||
if "Query failed" in tool_content or tool_content.strip().startswith("Error"):
|
||||
err_summary = "Die Datenbankabfrage ist fehlgeschlagen."
|
||||
if "no such column" in tool_content:
|
||||
err_summary += " Ein Spaltenname scheint nicht zu passen. Bitte die Anfrage anders formulieren."
|
||||
return {"messages": [AIMessage(content=err_summary)]}
|
||||
formatted_data = _tool_output_to_markdown_table(tool_content)
|
||||
logger.debug(f"agent_formulate: tool_content length={len(tool_content)}, formatted={len(formatted_data)}")
|
||||
ctx_len = _get_context_length(ctx)
|
||||
max_system_chars = max(3000, int(ctx_len * 0.5) * bytes_per_token) - len(FORMULATE_TASK)
|
||||
resp_struct = ctx.prompt_sections.get("response_structure") or ctx.prompt_sections.get("intro", "")
|
||||
intro_formulate = ctx.prompt_sections.get("intro", "")
|
||||
combined = f"{intro_formulate}\n\n{resp_struct}" if intro_formulate != resp_struct else resp_struct
|
||||
if len(combined) + len(FORMULATE_TASK) > max_system_chars:
|
||||
combined = _truncate_system_prompt(combined, max_system_chars - len(FORMULATE_TASK), "")
|
||||
system_content = combined + FORMULATE_TASK
|
||||
prompt = (
|
||||
f"Benutzerfrage: {human_content}\n\n"
|
||||
"--- VORGEGEBENE DATEN (diese Tabelle/Zahlen UNVERÄNDERT in die Antwort übernehmen): ---\n"
|
||||
f"{formatted_data}\n\n"
|
||||
"Die obige Tabelle bzw. Zahlen sind die EINZIGEN erlaubten Daten. Kopiere sie 1:1. "
|
||||
"Berechne keine eigenen Summen/Anzahlen - nutze die gelieferten Werte. Formuliere die Antwort:"
|
||||
)
|
||||
window = [SystemMessage(content=system_content), HumanMessage(content=prompt)]
|
||||
try:
|
||||
response = await ctx.model.ainvoke(window)
|
||||
except ValueError as exc:
|
||||
if "No suitable model found" in str(exc):
|
||||
response = AIMessage(content="Es gab einen Fehler bei der Formulierung. Bitte versuchen Sie es erneut.")
|
||||
else:
|
||||
raise
|
||||
if response.content:
|
||||
response = AIMessage(content=_sanitize_llm_response(response.content))
|
||||
return {"messages": [response]}
|
||||
|
||||
async def agent_tavily_node(state: ChatState) -> dict:
|
||||
ctx = _get_graph_context()
|
||||
resp_struct = ctx.prompt_sections.get("response_structure") or ""
|
||||
intro_tavily = ctx.prompt_sections.get("intro", "")
|
||||
combined = f"{intro_tavily}\n\n{resp_struct}" if resp_struct else intro_tavily
|
||||
system_content = _truncate_system_prompt(combined, 6000, "")
|
||||
tools_tavily = [t for t in [ctx.tavily_tool, ctx.streaming_tool] if t is not None]
|
||||
llm_tavily = ctx.model.bind_tools(tools=tools_tavily) if tools_tavily else ctx.model
|
||||
return await _agent_common(state, system_content, llm_tavily, "agent_tavily")
|
||||
|
||||
async def agent_answer_node(state: ChatState) -> dict:
|
||||
ctx = _get_graph_context()
|
||||
resp_struct = ctx.prompt_sections.get("response_structure") or ""
|
||||
intro_answer = ctx.prompt_sections.get("intro", "")
|
||||
combined = f"{intro_answer}\n\n{resp_struct}" if resp_struct else intro_answer
|
||||
system_content = _truncate_system_prompt(combined, 6000, "")
|
||||
llm = ctx.planner_model if ctx.planner_model else ctx.model
|
||||
return await _agent_common(state, system_content, llm, "agent_answer")
|
||||
|
||||
def should_continue_tavily(state: ChatState) -> str:
|
||||
last = state.messages[-1]
|
||||
return "tools" if getattr(last, "tool_calls", None) else END
|
||||
|
||||
def route_back(state: ChatState) -> str:
|
||||
ctx = _get_graph_context()
|
||||
return "agent_tavily" if ctx.tavily_tool else "agent_answer"
|
||||
|
||||
async def tools_with_retry(state: ChatState) -> dict:
|
||||
import asyncio
|
||||
ctx = _get_graph_context()
|
||||
last_message = state.messages[-1]
|
||||
tool_calls = getattr(last_message, "tool_calls", [])
|
||||
if not tool_calls:
|
||||
return {"messages": []}
|
||||
tools_by_name = ctx.tools_by_name
|
||||
|
||||
async def execute_single_tool(tool_call):
|
||||
tool_name = tool_call.get("name") or tool_call.get("function", {}).get("name")
|
||||
tool_id = tool_call.get("id", f"call_{tool_name}")
|
||||
args = tool_call.get("args") or tool_call.get("function", {}).get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
import json
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except Exception:
|
||||
args = {"input": args}
|
||||
tool = tools_by_name.get(tool_name)
|
||||
if not tool:
|
||||
return ToolMessage(content=f"Error: Tool '{tool_name}' not found", tool_call_id=tool_id, name=tool_name)
|
||||
try:
|
||||
if hasattr(tool, "coroutine") and asyncio.iscoroutinefunction(tool.coroutine):
|
||||
result = await tool.coroutine(**args)
|
||||
elif hasattr(tool, "ainvoke"):
|
||||
result = await tool.ainvoke(args)
|
||||
else:
|
||||
result = tool.invoke(args)
|
||||
return ToolMessage(content=str(result), tool_call_id=tool_id, name=tool_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool {tool_name} failed: {e}")
|
||||
return ToolMessage(content=f"Error executing {tool_name}: {str(e)}", tool_call_id=tool_id, name=tool_name)
|
||||
|
||||
tool_messages = await asyncio.gather(
|
||||
*[execute_single_tool(tc) for tc in tool_calls],
|
||||
return_exceptions=True
|
||||
)
|
||||
result_messages = []
|
||||
for i, msg in enumerate(tool_messages):
|
||||
if isinstance(msg, Exception):
|
||||
tool_call = tool_calls[i]
|
||||
tool_name = tool_call.get("name", "unknown")
|
||||
tool_id = tool_call.get("id", f"call_{i}")
|
||||
result_messages.append(ToolMessage(content=f"Error: {str(msg)}", tool_call_id=tool_id, name=tool_name))
|
||||
else:
|
||||
result_messages.append(msg)
|
||||
result = {"messages": result_messages}
|
||||
no_results_keywords = [
|
||||
"returned 0 rows", "no data", "keine artikel gefunden", "keine ergebnisse"
|
||||
]
|
||||
for msg in result.get("messages", []):
|
||||
content = getattr(msg, "content", "")
|
||||
if isinstance(content, str):
|
||||
content_lower = content.lower()
|
||||
if any(keyword in content_lower for keyword in no_results_keywords):
|
||||
retry_count = sum(1 for m in state.messages if "retry" in str(getattr(m, "content", "")).lower())
|
||||
if retry_count < 2:
|
||||
logger.info("No results found in tool output, adding retry instruction")
|
||||
retry_message = HumanMessage(
|
||||
content="WICHTIG: Die vorherige Suche hat keine Ergebnisse gefunden. "
|
||||
"Bitte versuche eine alternative Suchstrategie:\n"
|
||||
"1. Wenn die Frage im Format 'X von Y' war (z.B. 'Lampen von Eaton'), "
|
||||
"verwende IMMER eine Kombination aus Lieferanten-Filter (WHERE a.\"Lieferant\" LIKE '%Y%') "
|
||||
"UND Produkttyp-Filter (WHERE a.\"Artikelbezeichnung\" LIKE '%X%' OR ...)\n"
|
||||
"2. Verwende mehrere Synonyme für den Produkttyp (z.B. bei 'Lampen': Lampe, LED, Beleuchtung, Licht, Leuchte, Strahler)\n"
|
||||
"3. Führe zuerst eine COUNT-Abfrage durch, dann die Detail-Abfrage mit Lagerbeständen\n"
|
||||
"4. Verwende LIKE '%Lieferant%' für den Lieferanten-Filter, um auch Varianten zu finden"
|
||||
)
|
||||
result["messages"].append(retry_message)
|
||||
break
|
||||
return result
|
||||
|
||||
workflow = StateGraph(ChatState)
|
||||
workflow.add_node("planner", planner_node)
|
||||
workflow.add_node("agent_sql_plan", agent_sql_plan_node)
|
||||
workflow.add_node("parse_execute_sql", parse_execute_sql_node)
|
||||
workflow.add_node("agent_formulate", agent_formulate_node)
|
||||
workflow.add_node("tools", tools_with_retry)
|
||||
workflow.add_node("agent_tavily", agent_tavily_node)
|
||||
workflow.add_node("agent_answer", agent_answer_node)
|
||||
workflow.add_edge(START, "planner")
|
||||
workflow.add_conditional_edges("planner", route_by_plan)
|
||||
workflow.add_edge("agent_sql_plan", "parse_execute_sql")
|
||||
workflow.add_edge("parse_execute_sql", "agent_formulate")
|
||||
workflow.add_edge("agent_formulate", END)
|
||||
workflow.add_conditional_edges("agent_tavily", should_continue_tavily)
|
||||
workflow.add_edge("agent_answer", END)
|
||||
workflow.add_conditional_edges("tools", route_back)
|
||||
return workflow.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chatbot:
|
||||
"""Represents a chatbot."""
|
||||
|
||||
model: AICenterChatModel
|
||||
memory: DatabaseCheckpointer
|
||||
planner_model: Optional[AICenterChatModel] = None # Fast model for routing (SQL/TAVILY/NONE)
|
||||
app: CompiledStateGraph = None
|
||||
_tools: List[Any] = field(default_factory=list) # Configured tools (for cached graph context)
|
||||
system_prompt: str = "You are a helpful assistant."
|
||||
workflow_id: str = "default"
|
||||
config: Optional["ChatbotConfig"] = None
|
||||
|
|
@ -189,6 +595,7 @@ class Chatbot:
|
|||
workflow_id: str = "default",
|
||||
config: Optional["ChatbotConfig"] = None,
|
||||
event_manager=None,
|
||||
planner_model: Optional[AICenterChatModel] = None,
|
||||
) -> "Chatbot":
|
||||
"""Factory method to create and configure a Chatbot instance.
|
||||
|
||||
|
|
@ -199,6 +606,7 @@ class Chatbot:
|
|||
workflow_id: The workflow ID (maps to thread_id).
|
||||
config: Optional chatbot configuration for dynamic tool enablement.
|
||||
event_manager: Optional event manager for streaming (passed from route).
|
||||
planner_model: Optional fast model for planner/routing (default: same as model).
|
||||
|
||||
Returns:
|
||||
A configured Chatbot instance.
|
||||
|
|
@ -210,9 +618,11 @@ class Chatbot:
|
|||
workflow_id=workflow_id,
|
||||
config=config,
|
||||
_event_manager=event_manager,
|
||||
planner_model=planner_model,
|
||||
)
|
||||
configured_tools = await instance._configure_tools()
|
||||
instance.app = instance._build_app(memory, configured_tools)
|
||||
instance._tools = configured_tools
|
||||
instance.app = _get_or_build_cached_graph()
|
||||
return instance
|
||||
|
||||
async def _configure_tools(self) -> List[Any]:
|
||||
|
|
@ -281,6 +691,7 @@ class Chatbot:
|
|||
tools_sql = [t for t in [sql_tool, tavily_tool, streaming_tool] if t is not None]
|
||||
tools_tavily = [t for t in [tavily_tool, streaming_tool] if t is not None]
|
||||
llm_plain = self.model
|
||||
llm_planner = self.planner_model if self.planner_model else self.model
|
||||
# SQL path uses structured prompts + parse/execute (no native tool calling) - fits /api/analyze
|
||||
llm_tavily = self.model.bind_tools(tools=tools_tavily) if tools_tavily else self.model
|
||||
|
||||
|
|
@ -348,7 +759,8 @@ class Chatbot:
|
|||
|
||||
async def planner_node(state: ChatState) -> dict:
|
||||
"""Planner: minimal prompt, no tools. Outputs SQL/TAVILY/BOTH/NONE.
|
||||
Does NOT add planner message to chat - only sets state.plan for routing."""
|
||||
Does NOT add planner message to chat - only sets state.plan for routing.
|
||||
Uses llm_planner (fast model) when available for lower latency."""
|
||||
human_msgs = [m for m in state.messages if isinstance(m, HumanMessage)]
|
||||
last_human = human_msgs[-1].content if human_msgs else ""
|
||||
window = [
|
||||
|
|
@ -357,7 +769,7 @@ class Chatbot:
|
|||
]
|
||||
plan = "SQL"
|
||||
try:
|
||||
response = await llm_plain.ainvoke(window)
|
||||
response = await llm_planner.ainvoke(window)
|
||||
except ValueError as exc:
|
||||
if "No suitable model found" in str(exc):
|
||||
logger.warning(f"Planner model selection failed: {exc}")
|
||||
|
|
@ -555,12 +967,12 @@ class Chatbot:
|
|||
return await _agent_common(state, system_content, llm_tavily, "agent_tavily")
|
||||
|
||||
async def agent_answer_node(state: ChatState) -> dict:
|
||||
"""Agent with no tools. Uses intro + response_structure."""
|
||||
"""Agent with no tools (plan NONE). Uses fast model for lower latency."""
|
||||
resp_struct = _prompt_sections["response_structure"] or ""
|
||||
intro_answer = _prompt_sections["intro"]
|
||||
combined = f"{intro_answer}\n\n{resp_struct}" if resp_struct else intro_answer
|
||||
system_content = _truncate_system_prompt(combined, 6000, "")
|
||||
return await _agent_common(state, system_content, llm_plain, "agent_answer")
|
||||
return await _agent_common(state, system_content, llm_planner, "agent_answer")
|
||||
|
||||
def should_continue_tavily(state: ChatState) -> str:
|
||||
last = state.messages[-1]
|
||||
|
|
@@ -722,16 +1134,29 @@ class Chatbot:
        Returns:
            The list of messages in the chat history.
        """
        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}

        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        tools_by_name = {t.name: t for t in self._tools}
        graph_ctx = ChatbotGraphContext(
            model=self.model,
            planner_model=self.planner_model or self.model,
            tools=self._tools,
            tools_by_name=tools_by_name,
            sql_tool=tools_by_name.get("sqlite_query"),
            tavily_tool=tools_by_name.get("tavily_search"),
            streaming_tool=tools_by_name.get("send_streaming_message"),
            prompt_sections=_split_system_prompt(self.system_prompt),
            system_prompt=self.system_prompt,
        )

        # Extract and return the messages from the result
        return result["messages"]
        ctx_token = _set_graph_context(graph_ctx)
        cp_token = set_checkpointer(self.memory)
        try:
            result = await self.app.ainvoke(
                {"messages": [HumanMessage(content=message)]}, config=config
            )
            return result["messages"]
        finally:
            _reset_graph_context(ctx_token)
            reset_checkpointer(cp_token)

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
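The chat() rewrite above shares one cached graph and pushes per-call state (models, tools, prompt sections, checkpointer) into context that is set before the invoke and reset in a finally block. A minimal sketch of that set/reset-token pattern, assuming contextvars; the names GraphContext, _graph_context, and run_with_context are illustrative, not the project's:

# Per-call state in a ContextVar around a shared, reusable graph object.
import asyncio
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Optional

@dataclass
class GraphContext:
    model_name: str
    system_prompt: str

_graph_context: ContextVar[Optional[GraphContext]] = ContextVar("_graph_context", default=None)

async def shared_graph_invoke(message: str) -> str:
    ctx = _graph_context.get()
    if ctx is None:
        raise RuntimeError("graph context not set for this call")
    return f"[{ctx.model_name}] {ctx.system_prompt}: {message}"

async def run_with_context(ctx: GraphContext, message: str) -> str:
    token = _graph_context.set(ctx)    # like _set_graph_context(graph_ctx)
    try:
        return await shared_graph_invoke(message)
    finally:
        _graph_context.reset(token)    # like _reset_graph_context(ctx_token)

if __name__ == "__main__":
    print(asyncio.run(run_with_context(GraphContext("fast-model", "be brief"), "hello")))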
@@ -757,6 +1182,21 @@ class Chatbot:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        # Build tool lookup for cached graph context
        tools_by_name = {t.name: t for t in self._tools}
        graph_ctx = ChatbotGraphContext(
            model=self.model,
            planner_model=self.planner_model or self.model,
            tools=self._tools,
            tools_by_name=tools_by_name,
            sql_tool=tools_by_name.get("sqlite_query"),
            tavily_tool=tools_by_name.get("tavily_search"),
            streaming_tool=tools_by_name.get("send_streaming_message"),
            prompt_sections=_split_system_prompt(self.system_prompt),
            system_prompt=self.system_prompt,
        )
        ctx_token = _set_graph_context(graph_ctx)
        cp_token = set_checkpointer(self.memory)
        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
@@ -767,6 +1207,25 @@ class Chatbot:
                ename = event.get("name") or ""
                edata = event.get("data") or {}

                # Stream LLM tokens for ChatGPT-like incremental display
                if etype in ("on_llm_stream", "on_chat_model_stream"):
                    ch = edata.get("chunk")
                    if ch is None:
                        continue
                    # Chunk can be string, AIMessageChunk (has .content), or dict
                    content = ""
                    if isinstance(ch, str):
                        content = ch
                    elif hasattr(ch, "content"):
                        content = ch.content or ""
                        if isinstance(content, list):
                            content = "".join(str(x) for x in content)
                    elif isinstance(ch, dict):
                        content = ch.get("content", "") or ""
                    if isinstance(content, str) and content:
                        yield {"type": "chunk", "content": content}
                    continue

                # Stream human-readable progress via the special send_streaming_message tool
                # Match the legacy implementation exactly (line 267-272 in legacy/chatbot.py)
                if etype == "on_tool_start":
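The chunk handling above has to normalize three shapes of stream payload: plain strings, message chunks with a .content attribute (possibly a list of parts), and dicts. A self-contained sketch of just that normalization, with a fake chunk class standing in for the real message type:

# Normalizing heterogeneous stream chunks into plain text.
from dataclasses import dataclass
from typing import Any, List, Union

@dataclass
class FakeMessageChunk:
    content: Union[str, List[Any]]

def normalize_chunk(ch: Any) -> str:
    content: Any = ""
    if isinstance(ch, str):
        content = ch
    elif hasattr(ch, "content"):
        content = ch.content or ""
        if isinstance(content, list):
            content = "".join(str(x) for x in content)
    elif isinstance(ch, dict):
        content = ch.get("content", "") or ""
    return content if isinstance(content, str) else ""

if __name__ == "__main__":
    print(normalize_chunk("plain text"))
    print(normalize_chunk(FakeMessageChunk(content=["a", "b"])))
    print(normalize_chunk({"content": "from dict"}))
    print(normalize_chunk(None))  # -> "" (the caller skips empty content)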
@@ -833,3 +1292,6 @@ class Chatbot:
            # Emit a single error envelope and end the stream
            logger.error(f"Exception in stream_events: {exc}", exc_info=True)
            yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"}
        finally:
            _reset_graph_context(ctx_token)
            reset_checkpointer(cp_token)

@@ -397,8 +397,8 @@ class ChatObjects:
        dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
        dbPort = int(APP_CONFIG.get("DB_PORT", 5432))

        # Create database connector directly
        self.db = DatabaseConnector(
        from modules.connectors.connectorDbPostgre import _get_cached_connector
        self.db = _get_cached_connector(
            dbHost=dbHost,
            dbDatabase=dbDatabase,
            dbUser=dbUser,
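The switch from constructing a DatabaseConnector to _get_cached_connector reuses one connector per set of connection parameters, with eviction capped at 32 entries per the commit message. A rough standalone sketch of such a cache, assuming a simple LRU policy; FakeConnector and get_cached_connector are stand-ins for the real connector and helper:

# Process-wide connector cache keyed by connection parameters, max 32 entries.
from collections import OrderedDict
from threading import Lock

_MAX_CACHED = 32
_cache: "OrderedDict[tuple, FakeConnector]" = OrderedDict()
_cache_lock = Lock()

class FakeConnector:
    def __init__(self, host: str, database: str, user: str, port: int):
        self.dsn = (host, database, user, port)  # real code would open a pool here

def get_cached_connector(host: str, database: str, user: str, port: int = 5432) -> FakeConnector:
    key = (host, database, user, port)
    with _cache_lock:
        if key in _cache:
            _cache.move_to_end(key)          # keep recently used entries alive
            return _cache[key]
        conn = FakeConnector(host, database, user, port)
        _cache[key] = conn
        if len(_cache) > _MAX_CACHED:
            _cache.popitem(last=False)       # evict the least recently used connector
        return conn

if __name__ == "__main__":
    a = get_cached_connector("db1", "app", "svc")
    b = get_cached_connector("db1", "app", "svc")
    print(a is b)  # True: the second call is a cache hit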
@@ -769,6 +769,72 @@ class ChatObjects:
        """Backward-compat alias: workflowId maps to conversationId."""
        return self.getConversation(workflowId)

    def getWorkflowMinimal(self, workflowId: str) -> Optional[ChatbotConversation]:
        """Lightweight fetch: conversation record only, no logs/messages. For resume path."""
        conversations = getRecordsetWithRBAC(
            self.db,
            ChatbotConversation,
            self.currentUser,
            recordFilter={"id": workflowId},
            mandateId=self.mandateId,
            featureInstanceId=self.featureInstanceId,
            featureCode="chatbot",
        )
        if not conversations:
            return None
        conv = conversations[0]
        max_steps = conv.get("maxSteps")
        return ChatbotConversation(
            id=conv["id"],
            featureInstanceId=conv.get("featureInstanceId") or self.featureInstanceId or "",
            name=conv.get("name"),
            status=conv.get("status", "running"),
            currentRound=conv.get("currentRound", 0) or 0,
            lastActivity=conv.get("lastActivity", getUtcTimestamp()),
            startedAt=conv.get("startedAt", getUtcTimestamp()),
            workflowMode=ChatbotWorkflowModeEnum(conv.get("workflowMode", "Chatbot")),
            maxSteps=max_steps if max_steps is not None else 10,
            logs=[],
            messages=[],
        )

    def getMessageCount(self, conversationId: str) -> int:
        """Returns message count for a conversation (single query, no document fetch)."""
        messages = getRecordsetWithRBAC(
            self.db,
            ChatbotMessage,
            self.currentUser,
            recordFilter={"conversationId": conversationId},
            mandateId=self.mandateId,
            featureInstanceId=self.featureInstanceId,
            featureCode="chatbot",
        )
        return len(messages) if messages else 0

    def updateWorkflowMinimal(
        self, workflowId: str, workflowData: Dict[str, Any]
    ) -> ChatbotConversation:
        """Lightweight update: no logs/messages fetch. For resume when caller has minimal workflow."""
        if not self.checkRbacPermission(ChatbotConversation, "update", workflowId):
            raise PermissionError(f"No permission to update conversation {workflowId}")
        simpleFields, _ = self._separateObjectFields(ChatbotConversation, workflowData)
        simpleFields["lastActivity"] = getUtcTimestamp()
        updated = self.db.recordModify(ChatbotConversation, workflowId, simpleFields)
        max_steps = updated.get("maxSteps")
        return ChatbotConversation(
            id=updated["id"],
            featureInstanceId=updated.get("featureInstanceId") or self.featureInstanceId or "",
            name=updated.get("name"),
            status=updated.get("status", "running"),
            currentRound=updated.get("currentRound", 0) or 0,
            lastActivity=updated.get("lastActivity", getUtcTimestamp()),
            startedAt=updated.get("startedAt", getUtcTimestamp()),
            workflowMode=ChatbotWorkflowModeEnum(updated.get("workflowMode", "Chatbot")),
            maxSteps=max_steps if max_steps is not None else 10,
            logs=[],
            messages=[],
        )

    def createConversation(self, conversationData: Dict[str, Any]) -> ChatbotConversation:
        """Creates a new conversation if user has permission."""
        if not self.checkRbacPermission(ChatbotConversation, "create"):
@@ -825,9 +891,7 @@ class ChatObjects:

        updated = self.db.recordModify(ChatbotConversation, conversationId, simpleFields)

        logs = self.getLogs(conversationId)
        messages = self.getMessages(conversationId)

        # Reuse logs/messages from conv — update only touches simple fields, not related data
        return ChatbotConversation(
            id=updated["id"],
            featureInstanceId=updated.get("featureInstanceId") or conv.featureInstanceId or self.featureInstanceId or "",

@@ -838,8 +902,8 @@ class ChatObjects:
            startedAt=updated.get("startedAt", conv.startedAt),
            workflowMode=ChatbotWorkflowModeEnum(updated.get("workflowMode", conv.workflowMode.value)),
            maxSteps=updated.get("maxSteps") if updated.get("maxSteps") is not None else conv.maxSteps,
            logs=logs,
            messages=messages
            logs=conv.logs,
            messages=conv.messages
        )

    def updateWorkflow(self, workflowId: str, workflowData: Dict[str, Any]) -> ChatbotConversation:
@@ -955,11 +1019,13 @@ class ChatObjects:
        if pagination and pagination.sort:
            messageDicts = self._applySorting(messageDicts, pagination.sort)

        # If no pagination requested, return all items
        # If no pagination requested, return all items (batch-fetch documents to avoid N+1)
        if pagination is None:
            msg_ids = [m["id"] for m in messageDicts]
            docs_by_message = self.getDocumentsForMessages(msg_ids) if msg_ids else {}
            chat_messages = []
            for msg in messageDicts:
                documents = self.getDocuments(msg["id"])
                documents = docs_by_message.get(msg["id"], [])
                chat_message = ChatbotMessage(
                    id=msg["id"],
                    conversationId=msg["conversationId"],

@@ -994,10 +1060,11 @@ class ChatObjects:
        startIdx = (pagination.page - 1) * pagination.pageSize
        endIdx = startIdx + pagination.pageSize
        pagedMessageDicts = messageDicts[startIdx:endIdx]

        paged_msg_ids = [m["id"] for m in pagedMessageDicts]
        docs_by_message = self.getDocumentsForMessages(paged_msg_ids) if paged_msg_ids else {}
        chat_messages = []
        for msg in pagedMessageDicts:
            documents = self.getDocuments(msg["id"])
            documents = docs_by_message.get(msg["id"], [])
            chat_message = ChatbotMessage(
                id=msg["id"],
                conversationId=msg["conversationId"],
@@ -1224,6 +1291,17 @@ class ChatObjects:
            logger.error(f"Error updating message {messageId}: {str(e)}", exc_info=True)
            raise ValueError(f"Error updating message {messageId}: {str(e)}")

    def createStat(self, statData: Dict[str, Any]):
        """Create stat record. Compatibility with ChatService; stats may not be persisted in chatbot schema."""
        from modules.datamodels.datamodelChat import ChatStat
        stat = ChatStat(**statData)
        try:
            created = self.db.recordCreate(ChatStat, statData)
            return ChatStat(**created)
        except Exception as e:
            logger.debug(f"createStat: not persisting (chatbot schema): {e}")
            return stat

    def deleteMessage(self, conversationId: str, messageId: str) -> bool:
        """Deletes a conversation message and related data if user has access."""
        try:
@@ -1308,6 +1386,30 @@ class ChatObjects:
            logger.error(f"Error getting message documents: {str(e)}")
            return []

    def getDocumentsForMessages(self, messageIds: List[str]) -> Dict[str, List[ChatbotDocument]]:
        """Returns documents for multiple messages in one query. Returns {messageId: [doc, ...]}."""
        if not messageIds:
            return {}
        try:
            documents = getRecordsetWithRBAC(
                self.db,
                ChatbotDocument,
                self.currentUser,
                recordFilter={"messageId": messageIds},
                mandateId=self.mandateId,
                featureInstanceId=self.featureInstanceId,
                featureCode="chatbot",
            )
            result: Dict[str, List[ChatbotDocument]] = {mid: [] for mid in messageIds}
            for doc in documents:
                mid = doc.get("messageId")
                if mid in result:
                    result[mid].append(ChatbotDocument(**doc))
            return result
        except Exception as e:
            logger.error(f"Error getting documents for messages: {e}")
            return {mid: [] for mid in messageIds}

    def createDocument(self, documentData: Dict[str, Any]) -> ChatbotDocument:
        """Creates a document for a message in normalized table."""
        try:
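getDocumentsForMessages replaces the per-message document lookup (one query per message) with a single query over all message ids, grouped afterwards in memory. A self-contained sketch of that grouping step; the in-memory "database" and helper names are illustrative only:

# Bulk fetch plus grouping: one query instead of N, then a dict keyed by messageId.
from typing import Any, Dict, List

ALL_DOCS = [
    {"id": "d1", "messageId": "m1", "name": "a.pdf"},
    {"id": "d2", "messageId": "m1", "name": "b.pdf"},
    {"id": "d3", "messageId": "m2", "name": "c.pdf"},
]

def fetch_docs_bulk(message_ids: List[str]) -> List[Dict[str, Any]]:
    # stands in for one query with an IN-style filter
    wanted = set(message_ids)
    return [d for d in ALL_DOCS if d["messageId"] in wanted]

def documents_for_messages(message_ids: List[str]) -> Dict[str, List[Dict[str, Any]]]:
    result: Dict[str, List[Dict[str, Any]]] = {mid: [] for mid in message_ids}
    for doc in fetch_docs_bulk(message_ids):
        result[doc["messageId"]].append(doc)
    return result

if __name__ == "__main__":
    grouped = documents_for_messages(["m1", "m2", "m3"])
    print({mid: len(docs) for mid, docs in grouped.items()})  # {'m1': 2, 'm2': 1, 'm3': 0}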
@@ -1451,11 +1553,14 @@ class ChatObjects:

        items = []
        messages = getRecordsetWithRBAC(self.db, ChatbotMessage, self.currentUser, recordFilter={"conversationId": conversationId}, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode="chatbot")
        # Batch-fetch documents for all messages (avoids N+1)
        message_ids = [m["id"] for m in messages if afterTimestamp is None or parseTimestamp(m.get("publishedAt"), default=getUtcTimestamp()) > afterTimestamp]
        docs_by_message = self.getDocumentsForMessages(message_ids) if message_ids else {}
        for msg in messages:
            msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp())
            if afterTimestamp is not None and msgTimestamp <= afterTimestamp:
                continue
            documents = self.getDocuments(msg["id"])
            documents = docs_by_message.get(msg["id"], [])
            chatMessage = ChatbotMessage(
                id=msg["id"],
                conversationId=msg["conversationId"],

@@ -6,7 +6,7 @@ Handles feature initialization and RBAC catalog registration.
"""

import logging
from typing import Dict, List, Any
from typing import Dict, List, Any, Optional

logger = logging.getLogger(__name__)

@@ -48,6 +48,14 @@ RESOURCE_OBJECTS = [
    },
]

# Service requirements for chatbot — resolved via service center
REQUIRED_SERVICES = [
    {"serviceKey": "chat", "meta": {"usage": "File info, document handling"}},
    {"serviceKey": "ai", "meta": {"usage": "AI calls, conversation name generation"}},
    {"serviceKey": "billing", "meta": {"usage": "Usage tracking, balance checks"}},
    {"serviceKey": "streaming", "meta": {"usage": "Event manager, ChatStreamingHelper"}},
]

# Template roles for this feature
# Role names MUST follow convention: {featureCode}-{roleName}
TEMPLATE_ROLES = [
@@ -170,6 +178,76 @@ def registerFeature(catalogService) -> bool:
        return False


def getChatbotServices(
    user,
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None,
    workflow=None,
) -> "_ChatbotServiceHub":
    """
    Get lightweight service hub for chatbot (chat, ai, streaming) without loading
    the full legacy Services hub. Avoids ~90 ms from _loadFeatureInterfaces +
    _loadFeatureServices; only instantiates required services.
    Uses interfaceFeatureChatbot (ChatObjects) for interfaceDbChat to avoid
    duplicate DB init - chatProcess reuses hub.interfaceDbChat.
    """
    from modules.services import PublicService
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
    from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
    from modules.services.serviceChat.mainServiceChat import ChatService
    from modules.services.serviceAi.mainServiceAi import AiService
    from modules.services.serviceStreaming.mainServiceStreaming import StreamingService

    hub = _ChatbotServiceHub()
    hub.user = user
    hub.mandateId = mandateId
    hub.featureInstanceId = featureInstanceId
    hub.workflow = workflow
    hub.featureCode = "chatbot"
    hub.allowedProviders = None

    hub.interfaceDbApp = getAppInterface(user, mandateId=mandateId)
    # interfaceDbComponent: lazy-loaded on first access (saves ~100–300 ms when no file uploads)
    hub._interfaceDbComponent_val = None
    # Use ChatObjects (interfaceFeatureChatbot) - same as chatProcess, avoids extra interfaceDbChat init
    hub.interfaceDbChat = getChatbotInterface(
        user, mandateId=mandateId, featureInstanceId=featureInstanceId
    )

    hub.chat = PublicService(ChatService(hub))
    hub.ai = PublicService(AiService(hub), functionsOnly=False)
    hub.streaming = PublicService(StreamingService(hub))

    return hub


class _ChatbotServiceHub:
    """Lightweight hub with chat, ai, streaming for chatbot; avoids full Services init."""

    user = None
    mandateId = None
    featureInstanceId = None
    workflow = None
    interfaceDbApp = None
    _interfaceDbComponent_val = None
    interfaceDbChat = None

    @property
    def interfaceDbComponent(self):
        """Lazy-load interfaceDbComponent on first access (saves ~100–300 ms when no files)."""
        if self._interfaceDbComponent_val is None:
            from modules.interfaces.interfaceDbManagement import getInterface as getComponentInterface
            self._interfaceDbComponent_val = getComponentInterface(
                self.user, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId
            )
        return self._interfaceDbComponent_val

    chat = None
    ai = None
    streaming = None
    featureCode = "chatbot"
    allowedProviders = None


def _syncTemplateRolesToDb() -> int:
    """
    Sync template roles and their AccessRules to the database.

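The hub above keeps startup cheap by only building interfaceDbComponent when a request actually touches it. A small standalone sketch of that lazy-property idea; the class and attribute names here are invented for illustration:

# Lazy attribute on a lightweight hub: the expensive member is built on first access.
class ExpensiveComponentInterface:
    def __init__(self):
        print("building component interface (expensive)")

class LightweightHub:
    def __init__(self, user: str):
        self.user = user
        self._component = None          # nothing expensive at hub creation time

    @property
    def component(self):
        if self._component is None:     # built only when a request actually needs it
            self._component = ExpensiveComponentInterface()
        return self._component

if __name__ == "__main__":
    hub = LightweightHub("alice")       # fast: no component yet
    _ = hub.component                   # first access pays the cost
    _ = hub.component                   # second access reuses the same instance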
@@ -33,6 +33,17 @@ from modules.features.chatbot.interfaceFeatureChatbot import ChatbotConversation
from modules.features.chatbot import chatProcess
from modules.services.serviceStreaming import get_event_manager

# Pre-warm AI connectors when this router loads (before first request).
# Ensures connectors are ready; avoids 4–8 s delay on first chatbot message.
try:
    import modules.aicore.aicoreModelRegistry  # noqa: F401
    from modules.aicore.aicoreModelRegistry import modelRegistry
    modelRegistry.ensureConnectorsRegistered()
    modelRegistry.refreshModels(force=True)
    logging.getLogger(__name__).info("Chatbot router: AI connectors pre-warmed")
except Exception as e:
    logging.getLogger(__name__).warning(f"Chatbot AI pre-warm failed: {e}")

# Configure logger
logger = logging.getLogger(__name__)
@@ -43,12 +54,14 @@ router = APIRouter(
    responses={404: {"description": "Not found"}}
)

def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None):
    """Get chatbot interface with instance context."""
    mandateId = str(context.mandateId) if context.mandateId else None
def _getServiceChat(context: RequestContext, instanceId: Optional[str] = None, mandateId: Optional[str] = None):
    """Get chatbot interface with instance context.
    Pass mandateId when available (e.g. from _validateInstanceAccess) to ensure cache hit with getChatbotServices.
    """
    effective_mandate = mandateId if mandateId is not None else (str(context.mandateId) if context.mandateId else None)
    return interfaceDbChat.getInterface(
        context.user,
        mandateId=mandateId,
        context.user,
        mandateId=effective_mandate,
        featureInstanceId=instanceId
    )

@@ -125,7 +138,7 @@ def get_chatbot_threads(
    mandateId = _validateInstanceAccess(instanceId, context)

    try:
        interfaceDbChat = _getServiceChat(context, instanceId)
        interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)

        if workflowId:
            workflow = interfaceDbChat.getWorkflow(workflowId)
@@ -266,12 +279,14 @@ async def stream_chatbot_start(
    async def event_stream():
        """Async generator for SSE events - pure event-driven streaming (no polling)."""
        try:
            # Get interface for initial data and status checks
            interfaceDbChat = _getServiceChat(context, instanceId)

            # Get current workflow to check if resuming and get current round
            current_workflow = interfaceDbChat.getWorkflow(workflow.id)
            current_round = current_workflow.currentRound if current_workflow else None
            # Yield keepalive immediately so client gets 200 + first byte fast (normal chatbot feel)
            yield ": keepalive\n\n"

            # Use same mandateId as chatProcess so we hit interface cache (avoid duplicate DB init)
            interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)

            # Use workflow from chatProcess (no refetch)
            current_round = workflow.currentRound if workflow else None
            is_resuming = final_workflow_id is not None and current_round and current_round > 1

            # Send initial chat data (exact format as chatData endpoint) - only once at start

@@ -358,7 +373,7 @@ async def stream_chatbot_start(
                event_type = event.get("type")
                event_data = event.get("data", {})

                # Emit chatdata events (messages, logs, stats, status) in exact chatData format
                # Emit chatdata events (messages, logs, stats, status, chunk) in exact chatData format
                if event_type == "chatdata" and event_data:
                    # Handle status events (transient UI feedback)
                    if event_data.get("type") == "status":

@@ -368,6 +383,13 @@ async def stream_chatbot_start(
                            "label": event_data.get("label", "")
                        }
                        yield f"data: {json.dumps(status_item)}\n\n"
                    elif event_data.get("type") == "chunk":
                        # Token chunks for ChatGPT-like streaming
                        chunk_item = {
                            "type": "chunk",
                            "content": event_data.get("content", "")
                        }
                        yield f"data: {json.dumps(chunk_item)}\n\n"
                    else:
                        # Emit other chatdata items (messages, logs, stats) in exact chatData format
                        chatdata_item = event_data
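The event_stream generator above yields an SSE comment frame first so the client receives its first byte immediately, then forwards each event as a "data: {json}" frame. A minimal runnable sketch of that generator shape; the fake_events source is an assumption standing in for the project's event manager:

# SSE-style async generator: keepalive comment first, then JSON data frames.
import asyncio
import json

async def fake_events():
    for ev in ({"type": "status", "label": "Starte..."},
               {"type": "chunk", "content": "Hel"},
               {"type": "chunk", "content": "lo"}):
        await asyncio.sleep(0)
        yield ev

async def event_stream():
    yield ": keepalive\n\n"                      # SSE comment: first byte reaches the client fast
    async for event in fake_events():
        yield f"data: {json.dumps(event)}\n\n"   # one SSE frame per event

async def main():
    async for frame in event_stream():
        print(repr(frame))

if __name__ == "__main__":
    asyncio.run(main())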
@@ -444,7 +466,7 @@ async def stop_chatbot(

    try:
        # Get chatbot interface with instance context
        interfaceDbChat = _getServiceChat(context, instanceId)
        interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)

        # Get workflow to verify it exists and belongs to this instance
        workflow = interfaceDbChat.getWorkflow(workflowId)

@@ -519,8 +541,7 @@ def delete_chatbot(
    mandateId = _validateInstanceAccess(instanceId, context)

    try:
        # Get service center
        interfaceDbChat = _getServiceChat(context, instanceId)
        interfaceDbChat = _getServiceChat(context, instanceId, mandateId=mandateId)

        # Get workflow directly (interface already handles mandate filtering)
        workflow = interfaceDbChat.getWorkflow(workflowId)

@@ -19,7 +19,7 @@ from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, ProcessingModeEnum
from modules.datamodels.datamodelDocref import DocumentReferenceList, DocumentItemReference
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
from modules.services import getInterface as getServices
from modules.features.chatbot.mainChatbot import getChatbotServices
from modules.features.chatbot.chatbot import Chatbot
from modules.features.chatbot.bridges.ai import AICenterChatModel, clear_workflow_allowed_providers
from modules.features.chatbot.bridges.memory import DatabaseCheckpointer
@@ -91,33 +91,28 @@ async def chatProcess(
        ChatbotConversation instance
    """
    try:
        # Get services with mandate and feature instance context
        services = getServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
        # Get services from service center (only chat, ai, billing, streaming — avoids ~90ms legacy hub)
        services = getChatbotServices(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)
        services.featureCode = 'chatbot'

        # Load instance config and apply allowedProviders for AI calls (conversation name + main chat)
        chatbot_config = await _load_chatbot_config(featureInstanceId)
        if chatbot_config.model.allowedProviders:
            services.allowedProviders = chatbot_config.model.allowedProviders
            logger.info(f"Chatbot instance {featureInstanceId}: restricting to providers {chatbot_config.model.allowedProviders}")

        from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
        interfaceDbChat = getChatbotInterface(currentUser, mandateId=mandateId, featureInstanceId=featureInstanceId)

        # Config and model warm run in background task — return stream ~2–3 s faster for normal feel
        chatbot_config = None

        # Reuse hub's interfaceDbChat (ChatObjects) - avoids duplicate DB init
        interfaceDbChat = services.interfaceDbChat

        # Create or load workflow (event_manager passed from route)
        if workflowId:
            workflow = interfaceDbChat.getWorkflow(workflowId)
            # Lightweight resume: minimal fetch + minimal update (no logs/messages load)
            workflow = interfaceDbChat.getWorkflowMinimal(workflowId)
            if not workflow:
                raise ValueError(f"Workflow {workflowId} not found")

            # Resume workflow: increment round number
            new_round = workflow.currentRound + 1
            interfaceDbChat.updateWorkflow(workflowId, {
            workflow = interfaceDbChat.updateWorkflowMinimal(workflowId, {
                "status": "running",
                "currentRound": new_round,
                "lastActivity": getUtcTimestamp()
            })
            workflow = interfaceDbChat.getWorkflow(workflowId)
            logger.info(f"Resumed workflow {workflowId}, round incremented to {new_round}")

            # Create event queue if it doesn't exist (for streaming)
@@ -166,10 +161,7 @@ async def chatProcess(

        # Create event queue for new workflow (for streaming)
        event_manager.create_queue(workflow.id)

        # Reload workflow to get current message count
        workflow = interfaceDbChat.getWorkflow(workflow.id)

        # Process uploaded files and create ChatbotDocuments
        user_documents = []
        if userInput.listFileId and len(userInput.listFileId) > 0:
@@ -203,14 +195,19 @@ async def chatProcess(
            except Exception as e:
                logger.error(f"Error processing file ID {fileId}: {e}", exc_info=True)

        # Store user message
        # Store user message (sequenceNr: for resume use message count, else len+1)
        seq_nr = (
            interfaceDbChat.getMessageCount(workflow.id) + 1
            if workflowId
            else len(workflow.messages) + 1
        )
        userMessageData: Dict[str, Any] = {
            "id": f"msg_{uuid.uuid4()}",
            "conversationId": workflow.id,
            "message": userInput.prompt,
            "role": "user",
            "status": "first" if workflowId is None else "step",
            "sequenceNr": len(workflow.messages) + 1,
            "sequenceNr": seq_nr,
            "publishedAt": getUtcTimestamp(),
            "roundNumber": workflow.currentRound,
            "taskNumber": 0,
@@ -223,11 +220,12 @@ async def chatProcess(
        userMessage = interfaceDbChat.createMessage(userMessageData, event_manager=None)
        logger.info(f"Stored user message: {userMessage.id} with {len(user_documents)} document(s)")

        # Update workflow status
        interfaceDbChat.updateWorkflow(workflow.id, {
            "status": "running",
            "lastActivity": getUtcTimestamp()
        })
        # Update workflow status (minimal update for lastActivity; resume already did this)
        if not workflowId:
            interfaceDbChat.updateWorkflowMinimal(workflow.id, {
                "status": "running",
                "lastActivity": getUtcTimestamp()
            })

        # Pre-flight billing check before starting LangGraph (if mandateId present)
        if mandateId:
@@ -244,9 +242,6 @@ async def chatProcess(
            config=chatbot_config,
            event_manager=event_manager
        ))

        # Reload workflow to include new message
        workflow = interfaceDbChat.getWorkflow(workflow.id)
        return workflow

    except Exception as e:
@@ -728,8 +723,7 @@ async def _convert_file_ids_to_document_references(
        if not document_id:
            try:
                from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
                from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
                chatbotInterface = getChatbotInterface(services.user, mandateId=services.mandateId, featureInstanceId=services.featureInstanceId)
                chatbotInterface = services.interfaceDbChat
                documents = getRecordsetWithRBAC(
                    chatbotInterface.db,
                    ChatbotDocument,
@@ -928,6 +922,20 @@ async def _bridge_chatbot_events(
                step="status"
            )
            continue

        # Handle token chunks for ChatGPT-like streaming (append to message as it's generated)
        if event_type == "chunk":
            content = event.get("content", "")
            if content:
                await event_manager.emit_event(
                    context_id=workflow_id,
                    event_type="chatdata",
                    data={"type": "chunk", "content": content},
                    event_category="chat",
                    message="Token chunk",
                    step="chunk"
                )
            continue

        # Handle final response
        if event_type == "final":
@@ -1117,9 +1125,59 @@ async def _bridge_chatbot_events(
        clear_workflow_allowed_providers(workflow_id)


def _load_chatbot_config_sync(featureInstanceId: Optional[str]) -> ChatbotConfig:
    """
    Load chatbot configuration from FeatureInstance (database). Sync version for use in executor.

    Args:
        featureInstanceId: Feature instance ID to load config from

    Returns:
        ChatbotConfig instance
    """
    if not featureInstanceId:
        raise ValueError("featureInstanceId is required to load chatbot config")

    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.interfaces.interfaceFeatures import getFeatureInterface

    rootInterface = getRootInterface()
    featureInterface = getFeatureInterface(rootInterface.db)
    instance = featureInterface.getFeatureInstance(featureInstanceId)

    if not instance:
        raise ValueError(f"FeatureInstance {featureInstanceId} not found")

    logger.info(f"Loading chatbot config from FeatureInstance {featureInstanceId}")
    return load_chatbot_config_from_instance(instance)


def _warm_model_registry_cache_sync(
    currentUser: "User",
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None,
) -> None:
    """
    Pre-warm getAvailableModels cache so planner/agent model selection is a cache hit.
    Uses mandateId/featureInstanceId for faster RBAC (fewer roles to load).
    Runs in executor to avoid blocking event loop.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.aicore.aicoreModelRegistry import modelRegistry

    root = getRootInterface()
    modelRegistry.getAvailableModels(
        currentUser=currentUser,
        rbacInstance=root.rbac,
        mandateId=mandateId,
        featureInstanceId=featureInstanceId,
    )


async def _load_chatbot_config(featureInstanceId: Optional[str]) -> ChatbotConfig:
    """
    Load chatbot configuration from FeatureInstance (database).
    Runs in thread pool to avoid blocking event loop on DB I/O.

    Args:
        featureInstanceId: Feature instance ID to load config from
@@ -1134,20 +1192,7 @@ async def _load_chatbot_config(featureInstanceId: Optional[str]) -> ChatbotConfig:
        raise ValueError("featureInstanceId is required to load chatbot config")

    try:
        # Import here to avoid circular imports
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceFeatures import getFeatureInterface

        # Get feature instance from database
        rootInterface = getRootInterface()
        featureInterface = getFeatureInterface(rootInterface.db)
        instance = featureInterface.getFeatureInstance(featureInstanceId)

        if not instance:
            raise ValueError(f"FeatureInstance {featureInstanceId} not found")

        logger.info(f"Loading chatbot config from FeatureInstance {featureInstanceId}")
        return load_chatbot_config_from_instance(instance)
        return await asyncio.to_thread(_load_chatbot_config_sync, featureInstanceId)
    except ValueError:
        raise
    except Exception as e:
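The async wrapper now delegates the blocking database read to a worker thread via asyncio.to_thread, so the event loop stays free while config loads. A tiny standalone sketch of that offloading pattern; load_config_sync is an invented stand-in for the project's sync loader:

# Offload a blocking sync loader to a thread and await it from async code.
import asyncio
import time

def load_config_sync(instance_id: str) -> dict:
    time.sleep(0.1)                 # simulates blocking DB I/O
    return {"instance": instance_id, "maxSteps": 10}

async def load_config(instance_id: str) -> dict:
    # runs the sync loader in the default thread pool; the loop stays responsive
    return await asyncio.to_thread(load_config_sync, instance_id)

async def main():
    cfg, _ = await asyncio.gather(load_config("abc"), asyncio.sleep(0.05))
    print(cfg)

if __name__ == "__main__":
    asyncio.run(main())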
@@ -1250,9 +1295,25 @@ async def _processChatbotMessageLangGraph(
        event_manager: Event manager for streaming (passed from chatProcess)
    """
    try:
        from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
        interfaceDbChat = getChatbotInterface(currentUser, mandateId=services.mandateId, featureInstanceId=featureInstanceId)

        # Start config + model cache warm in parallel (planner/agent need cache hit to avoid 2–3 s per call)
        config_task = asyncio.create_task(_load_chatbot_config(featureInstanceId)) if config is None else None
        warm_task = asyncio.create_task(asyncio.to_thread(
            _warm_model_registry_cache_sync, currentUser, services.mandateId, featureInstanceId
        ))

        # Emit first status immediately so stream feels responsive
        await event_manager.emit_event(
            context_id=workflowId,
            event_type="chatdata",
            data={"type": "status", "label": "Starte..."},
            event_category="chat",
            message="Status update",
            step="status",
        )

        # Reuse interfaceDbChat from services (ChatObjects) - avoids duplicate DB init
        interfaceDbChat = services.interfaceDbChat

        # Reload workflow to get current messages
        workflow = interfaceDbChat.getWorkflow(workflowId)
        if not workflow:
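The hunk above starts the config load and the model-cache warm as tasks, emits the first status event right away, and only awaits the task results where they are needed. A compact runnable sketch of that "start early, await late" pattern; the workload functions are stand-ins, not the project's helpers:

# Overlap background work with immediate user feedback.
import asyncio
import time

def warm_cache_sync():
    time.sleep(0.1)                 # simulates a blocking cache warm (runs in a thread)

async def load_config():
    await asyncio.sleep(0.1)        # simulates an async config load
    return {"systemPrompt": "example"}

async def emit_status(label: str):
    print("status:", label)

async def process_message():
    config_task = asyncio.create_task(load_config())
    warm_task = asyncio.create_task(asyncio.to_thread(warm_cache_sync))

    await emit_status("Starte...")  # user sees feedback while both tasks run

    config = await config_task      # awaited only once the value is actually needed
    await warm_task
    print("ready with config:", config)

if __name__ == "__main__":
    asyncio.run(process_message())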
@@ -1272,19 +1333,10 @@ async def _processChatbotMessageLangGraph(
            logger.info(f"Workflow {workflowId} was stopped, aborting processing")
            return

        # Emit synthetic status for real-time UI feedback
        await event_manager.emit_event(
            context_id=workflowId,
            event_type="chatdata",
            data={"type": "status", "label": "Lade Konfiguration..."},
            event_category="chat",
            message="Status update",
            step="status"
        )

        # Load configuration if not passed (e.g. when resuming)
        if config is None:
            config = await _load_chatbot_config(featureInstanceId)
        # Await config and model cache warm (planner gets cache hit, saves ~2–3 s)
        if config_task is not None:
            config = await config_task
        await warm_task

        # Replace {{DATE}} placeholder in system prompt
        from datetime import datetime
@@ -1310,25 +1362,30 @@ async def _processChatbotMessageLangGraph(
            processing_mode=processing_mode,
            billing_callback=billing_callback,
            workflow_id=workflowId,
            allowed_providers=allowed_providers
            allowed_providers=allowed_providers,
            mandate_id=services.mandateId,
            feature_instance_id=featureInstanceId,
        )
        # Fast planner model (gpt-4o-mini etc.) for routing - saves ~1-2 s on first response
        planner_model = AICenterChatModel(
            user=currentUser,
            operation_type=operation_type,
            processing_mode=processing_mode,
            billing_callback=billing_callback,
            workflow_id=workflowId,
            allowed_providers=allowed_providers,
            prefer_fast_model=True,
            mandate_id=services.mandateId,
            feature_instance_id=featureInstanceId,
        )

        # Emit synthetic status for real-time UI feedback
        await event_manager.emit_event(
            context_id=workflowId,
            event_type="chatdata",
            data={"type": "status", "label": "Bereite Chat vor..."},
            event_category="chat",
            message="Status update",
            step="status"
        )

        # Create memory/checkpointer (uses chatbot's own DB via interfaceFeatureChatbot)
        # Create memory/checkpointer (reuse interface to avoid extra DB init)
        memory = DatabaseCheckpointer(
            user=currentUser,
            workflow_id=workflowId,
            mandateId=services.mandateId,
            featureInstanceId=featureInstanceId
            featureInstanceId=featureInstanceId,
            interface=interfaceDbChat,
        )

        # Create chatbot instance with config for dynamic tool configuration

@@ -1338,7 +1395,8 @@ async def _processChatbotMessageLangGraph(
            system_prompt=system_prompt,
            workflow_id=workflowId,
            config=config,
            event_manager=event_manager
            event_manager=event_manager,
            planner_model=planner_model,
        )

        # Emit synthetic status for real-time UI feedback

@@ -15,7 +15,7 @@ from typing import Dict, Any, List, Optional, Union
from passlib.context import CryptContext
import uuid

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector
from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
from modules.interfaces.interfaceBootstrap import initBootstrap

@@ -144,8 +144,7 @@ class AppObjects:
        dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
        dbPort = int(APP_CONFIG.get("DB_PORT", 5432))

        # Create database connector directly
        self.db = DatabaseConnector(
        self.db = _get_cached_connector(
            dbHost=dbHost,
            dbDatabase=dbDatabase,
            dbUser=dbUser,

@@ -12,7 +12,7 @@ import hashlib
import math
from typing import Dict, Any, List, Optional, Union

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
from modules.security.rbac import RbacClass
from modules.datamodels.datamodelRbac import AccessRuleContext

@@ -131,8 +131,7 @@ class ComponentObjects:
        dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
        dbPort = int(APP_CONFIG.get("DB_PORT", 5432))

        # Create database connector directly
        self.db = DatabaseConnector(
        self.db = _get_cached_connector(
            dbHost=dbHost,
            dbDatabase=dbDatabase,
            dbUser=dbUser,

@@ -11,7 +11,7 @@ Multi-Tenant Design:
"""

import logging
from typing import List, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, TYPE_CHECKING
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role
from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
from modules.datamodels.datamodelMembership import (
@@ -163,6 +163,52 @@ class RbacClass:

        return permissions

    def checkResourceAccessBulk(
        self,
        user: User,
        resourcePaths: List[str],
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ) -> Dict[str, bool]:
        """
        Check view access for multiple RESOURCE paths in one pass.
        Uses same logic as getUserPermissions but batches DB access.
        Returns {path: has_view}.
        """
        result = {p: False for p in resourcePaths}
        if not resourcePaths:
            return result
        # SysAdmin bypass
        if hasattr(user, "isSysAdmin") and user.isSysAdmin:
            return {p: True for p in resourcePaths}
        roleIds = self._getRoleIdsForUser(user, mandateId, featureInstanceId)
        if not roleIds:
            return result
        rulesWithPriority = self._getRulesForRoleIds(
            roleIds, AccessRuleContext.RESOURCE, mandateId, featureInstanceId
        )
        for path in resourcePaths:
            rolePermissions = {}
            for priority, rule in rulesWithPriority:
                if not self._ruleMatchesItem(rule, path):
                    continue
                roleId = rule.roleId
                itemSpecificity = self._getItemSpecificity(rule, path)
                if roleId not in rolePermissions:
                    rolePermissions[roleId] = (priority, itemSpecificity, rule)
                else:
                    existingPriority, existingSpecificity, _ = rolePermissions[roleId]
                    if priority > existingPriority or (
                        priority == existingPriority and itemSpecificity > existingSpecificity
                    ):
                        rolePermissions[roleId] = (priority, itemSpecificity, rule)
            highestPriority = max((p for p, _, _ in rolePermissions.values()), default=0)
            for _, (priority, _, rule) in rolePermissions.items():
                if priority >= highestPriority and rule.view:
                    result[path] = True
                    break
        return result

    def _getRoleIdsForUser(
        self,
        user: User,
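checkResourceAccessBulk loads the matching rules once and resolves every requested path against that in-memory rule set, instead of repeating the role and rule lookups per path. A simplified standalone sketch of that bulk-resolution idea; rule matching is reduced to prefix matching here and the types and data are invented for illustration, not the project's RBAC model:

# Bulk view-access check: one rules fetch, many paths resolved in memory.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class Rule:
    role_id: str
    path_prefix: str
    view: bool

def load_rules_once() -> List[Tuple[int, Rule]]:
    # stands in for a single batched fetch; returns (priority, rule) pairs
    return [
        (1, Rule("viewer", "reports/", True)),
        (2, Rule("restricted", "reports/secret/", False)),
    ]

def check_access_bulk(paths: List[str]) -> Dict[str, bool]:
    rules = load_rules_once()
    result = {p: False for p in paths}
    for path in paths:
        best: Optional[Tuple[int, int, Rule]] = None
        for priority, rule in rules:
            if not path.startswith(rule.path_prefix):
                continue
            specificity = len(rule.path_prefix)
            if best is None or (priority, specificity) > (best[0], best[1]):
                best = (priority, specificity, rule)
        result[path] = bool(best and best[2].view)
    return result

if __name__ == "__main__":
    print(check_access_bulk(["reports/q1", "reports/secret/x", "other/doc"]))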