gateway/modules/aicore/aicoreModelRegistry.py
Ida Dittrich 6dc2afafb9 fix: performance improvements
- app.py: Pre-warm AI connectors at module load and in lifespan
- aicoreModelRegistry.py: Connector discovery cache, getAvailableModels cache, bulk RBAC, eager prewarm
- connectorDbPostgre.py: Connector cache, contextvars for userId, eviction (max 32)
- chatbot: Uses _get_cached_connector, Service center integration, BillingService exceptions instead of direct imports
- interfaceDbApp.py: Uses _get_cached_connector
- interfaceDbManagement.py: Uses _get_cached_connector
- security/rbac.py: Adds checkResourceAccessBulk
2026-03-06 13:46:54 +01:00

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Dynamic model registry that collects models from all AI connectors.
Implements plugin-like architecture for connector discovery.
"""
import logging
import importlib
import os
import time
from typing import Dict, List, Optional, Any, Tuple
from modules.datamodels.datamodelAi import AiModel
from .aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelUam import User
from modules.security.rbacHelpers import checkResourceAccess
from modules.security.rbac import RbacClass
from modules.connectors.connectorDbPostgre import DatabaseConnector
logger = logging.getLogger(__name__)
# TODO TESTING: Override maxTokens for all models during testing
# Set to None to disable override, or set to an integer (e.g., 20000) to override all models
TESTING_MAX_TOKENS_OVERRIDE: Optional[int] = None # TODO TESTING: Set to None to disable
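# --- Illustrative sketch (not part of the runtime code): what a discoverable
# connector plugin looks like. Discovery below scans this package for files
# named aicorePlugin*.py and instantiates every BaseConnectorAi subclass found.
# The method names mirror those this registry actually calls
# (getConnectorType, getCachedModels, clearCache); the AiModel constructor
# arguments shown are assumptions based on how this file uses the model fields.
#
#   # modules/aicore/aicorePluginExample.py (hypothetical file name)
#   class ExampleConnector(BaseConnectorAi):
#       def getConnectorType(self) -> str:
#           return "example"
#       def getCachedModels(self) -> List[AiModel]:
#           return [AiModel(name="example-1", displayName="Example 1",
#                           connectorType="example", maxTokens=8192,
#                           isAvailable=True)]
#       def clearCache(self) -> None:
#           pass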
class ModelRegistry:
"""Dynamic registry for AI models from all connectors."""
def __init__(self):
self._models: Dict[str, AiModel] = {}
self._connectors: Dict[str, BaseConnectorAi] = {}
self._lastRefresh: Optional[float] = None
self._refreshInterval: float = 300.0 # 5 minutes
self._connectorsInitialized: bool = False
self._discoveredConnectorsCache: Optional[List[BaseConnectorAi]] = None # Avoid re-instantiating on every discoverConnectors() call
self._getAvailableModelsCache: Dict[Tuple[str, int], Tuple[List[AiModel], float]] = {} # (user_id, rbac_id) -> (models, ts)
self._getAvailableModelsCacheTtl: float = 30.0 # seconds
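    # Caching layers used by this registry:
    #   1) _discoveredConnectorsCache: plugin classes are imported and
    #      instantiated once per process (discoverConnectors).
    #   2) Per-connector model caches behind getCachedModels()/clearCache().
    #   3) _getAvailableModelsCache: RBAC-filtered results memoized per
    #      (user, rbac) pair for _getAvailableModelsCacheTtl seconds.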
def registerConnector(self, connector: BaseConnectorAi):
"""Register a connector and collect its models."""
connectorType = connector.getConnectorType()
# If connector already registered, skip re-registration to avoid duplicate models
if connectorType in self._connectors:
logger.debug(f"Connector {connectorType} already registered, skipping re-registration")
return
self._connectors[connectorType] = connector
        # Collect models from this connector
        try:
            self._ingestModels(connector)
        except Exception as e:
            logger.error(f"Failed to register models from {connectorType}: {e}")
            raise
    def _ingestModels(self, connector: BaseConnectorAi):
        """Collect a connector's models into the registry, enforcing displayName uniqueness. Shared by registerConnector and refreshModels."""
        connectorType = connector.getConnectorType()
        for model in connector.getCachedModels():
            # Validate displayName uniqueness
            if model.displayName in self._models:
                existingModel = self._models[model.displayName]
                errorMsg = f"Duplicate displayName '{model.displayName}' detected! Existing model: displayName='{existingModel.displayName}', name='{existingModel.name}' (connector: {existingModel.connectorType}), New model: displayName='{model.displayName}', name='{model.name}' (connector: {connectorType}). displayName must be unique."
                logger.error(errorMsg)
                raise ValueError(errorMsg)
            # TODO TESTING: Override maxTokens if testing override is enabled
            if TESTING_MAX_TOKENS_OVERRIDE is not None and model.maxTokens > TESTING_MAX_TOKENS_OVERRIDE:
                originalMaxTokens = model.maxTokens
                model.maxTokens = TESTING_MAX_TOKENS_OVERRIDE
                logger.debug(f"TESTING: Overrode maxTokens for {model.displayName}: {originalMaxTokens} -> {TESTING_MAX_TOKENS_OVERRIDE}")
            # Use displayName as the key (must be unique)
            self._models[model.displayName] = model
            logger.debug(f"Registered model: {model.displayName} (name: {model.name}) from {connectorType}")
def discoverConnectors(self) -> List[BaseConnectorAi]:
"""Auto-discover connectors by scanning aicorePlugin*.py files. Cached after first call to avoid 4-8 s re-init on every use."""
if self._discoveredConnectorsCache is not None:
return self._discoveredConnectorsCache
connectors = []
connectorDir = os.path.dirname(__file__)
# Scan for connector files
for filename in os.listdir(connectorDir):
if filename.startswith('aicorePlugin') and filename.endswith('.py'):
moduleName = filename[:-3] # Remove .py extension
try:
# Import the module
module = importlib.import_module(f'modules.aicore.{moduleName}')
# Find connector classes (classes that inherit from BaseConnectorAi)
for attrName in dir(module):
attr = getattr(module, attrName)
if (isinstance(attr, type) and
issubclass(attr, BaseConnectorAi) and
attr != BaseConnectorAi):
# Instantiate the connector
connector = attr()
connectors.append(connector)
logger.info(f"Discovered connector: {connector.getConnectorType()}")
except Exception as e:
logger.warning(f"Failed to discover connector from {filename}: {e}")
self._discoveredConnectorsCache = connectors
return connectors
def ensureConnectorsRegistered(self):
"""Register connectors once to avoid per-request discovery."""
if self._connectorsInitialized:
return
discovered = self.discoverConnectors()
for connector in discovered:
self.registerConnector(connector)
self._connectorsInitialized = True
def refreshModels(self, force: bool = False):
"""Refresh models from all registered connectors."""
self.ensureConnectorsRegistered()
currentTime = time.time()
# Check if refresh is needed
if (not force and
self._lastRefresh is not None and
currentTime - self._lastRefresh < self._refreshInterval):
return
logger.info("Refreshing model registry...")
# Clear existing models
self._models.clear()
        # Re-register all connectors
        for connector in self._connectors.values():
            try:
                connector.clearCache()  # Clear connector cache
                self._ingestModels(connector)
            except Exception as e:
                logger.error(f"Failed to refresh models from {connector.getConnectorType()}: {e}")
                raise
self._lastRefresh = currentTime
logger.info(f"Model registry refreshed: {len(self._models)} models available")
def getModels(self) -> List[AiModel]:
"""Get all available models."""
self.refreshModels()
return list(self._models.values())
def getModelsByConnector(self, connectorType: str) -> List[AiModel]:
"""Get models from a specific connector."""
self.refreshModels()
return [model for model in self._models.values() if model.connectorType == connectorType]
def getModelsByPriority(self, priority: str) -> List[AiModel]:
"""Get models that have a specific priority."""
self.refreshModels()
return [model for model in self._models.values() if model.priority == priority]
def getAvailableModels(
self,
currentUser: Optional[User] = None,
rbacInstance: Optional[RbacClass] = None,
mandateId: Optional[str] = None,
featureInstanceId: Optional[str] = None
) -> List[AiModel]:
"""Get only available models, optionally filtered by RBAC permissions.
Results are cached per (user, rbac) for 30s to avoid repeated filtering on each LLM call.
Args:
currentUser: Optional user object for RBAC filtering
rbacInstance: Optional RBAC instance for permission checks
mandateId: Optional mandate context for faster RBAC (loads fewer roles)
featureInstanceId: Optional feature instance for RBAC context
Returns:
List of available models (filtered by RBAC if user provided)
"""
self.refreshModels()
cache_key = (currentUser.id if currentUser else "", id(rbacInstance) if rbacInstance else 0)
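        # Note: id(rbacInstance) can be reused after the instance is garbage
        # collected, so a stale cache hit is theoretically possible; the short
        # TTL below bounds the impact.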
now = time.time()
if cache_key in self._getAvailableModelsCache:
cached_models, cached_ts = self._getAvailableModelsCache[cache_key]
if now - cached_ts < self._getAvailableModelsCacheTtl:
logger.debug(f"getAvailableModels: cache hit for user={cache_key[0][:8] if cache_key[0] else 'anon'}...")
return cached_models
allModels = list(self._models.values())
availableModels = [model for model in allModels if model.isAvailable]
# Apply RBAC filtering if user and RBAC instance provided (batch check for performance)
if currentUser and rbacInstance:
availableModels = self._filterModelsByRbac(
availableModels, currentUser, rbacInstance, mandateId, featureInstanceId
)
self._getAvailableModelsCache[cache_key] = (availableModels, now)
# Prune expired entries to avoid unbounded growth
expired = [k for k, (_, ts) in self._getAvailableModelsCache.items() if now - ts >= self._getAvailableModelsCacheTtl]
for k in expired:
del self._getAvailableModelsCache[k]
unavailableCount = len(allModels) - len(availableModels)
if unavailableCount > 0:
unavailableModels = [m.name for m in allModels if not m.isAvailable]
logger.debug(f"getAvailableModels: {len(availableModels)} available, {unavailableCount} unavailable. Unavailable: {unavailableModels}")
logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
return availableModels
def _filterModelsByRbac(
self,
models: List[AiModel],
currentUser: User,
rbacInstance: RbacClass,
mandateId: Optional[str] = None,
featureInstanceId: Optional[str] = None
) -> List[AiModel]:
"""Filter models based on RBAC permissions. Uses bulk check for performance."""
paths = []
model_paths = {} # model -> (connector_path, model_path)
for model in models:
connector_path = f"ai.model.{model.connectorType}"
model_path = f"ai.model.{model.connectorType}.{model.displayName}"
paths.extend([connector_path, model_path])
model_paths[id(model)] = (connector_path, model_path)
# Single bulk RBAC call instead of 2*N per-model calls
access = rbacInstance.checkResourceAccessBulk(
currentUser, list(dict.fromkeys(paths)), mandateId, featureInstanceId
)
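        # checkResourceAccessBulk returns a mapping of resource path -> bool,
        # one entry per requested path (duplicates removed via dict.fromkeys).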
filteredModels = []
for model in models:
connector_path, model_path = model_paths[id(model)]
if access.get(connector_path, False) or access.get(model_path, False):
filteredModels.append(model)
else:
logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")
return filteredModels
def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]:
"""Get a specific model by displayName, optionally checking RBAC permissions.
Args:
displayName: Model display name
currentUser: Optional user object for RBAC check
rbacInstance: Optional RBAC instance for permission check
Returns:
Model if found and user has access (or if no user provided), None otherwise
"""
self.refreshModels()
model = self._models.get(displayName)
if not model:
return None
# Check RBAC permission if user provided
if currentUser and rbacInstance:
connectorResourcePath = f"ai.model.{model.connectorType}"
modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"
hasConnectorAccess = checkResourceAccess(rbacInstance, currentUser, connectorResourcePath)
hasModelAccess = checkResourceAccess(rbacInstance, currentUser, modelResourcePath)
if not (hasConnectorAccess or hasModelAccess):
logger.warning(f"User {currentUser.username} does not have access to model {displayName}")
return None
return model
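    # Design note: getModel's single-model path keeps the two individual
    # checkResourceAccess calls; the bulk variant in _filterModelsByRbac only
    # pays off when filtering many models at once.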
def getConnectorForModel(self, displayName: str) -> Optional[BaseConnectorAi]:
"""Get the connector instance for a specific model by displayName."""
model = self.getModel(displayName)
if model:
return self._connectors.get(model.connectorType)
return None
def getModelStats(self) -> Dict[str, Any]:
"""Get statistics about the model registry."""
self.refreshModels()
stats = {
"totalModels": len(self._models),
"availableModels": len([m for m in self._models.values() if m.isAvailable]),
"connectors": len(self._connectors),
"byConnector": {},
"byCapability": {},
"byPriority": {}
}
        # Count by connector, capability, and priority in a single pass
        for model in self._models.values():
            stats["byConnector"][model.connectorType] = stats["byConnector"].get(model.connectorType, 0) + 1
            for capability in model.capabilities:
                stats["byCapability"][capability] = stats["byCapability"].get(capability, 0) + 1
            stats["byPriority"][model.priority] = stats["byPriority"].get(model.priority, 0) + 1
return stats
# Global registry instance
modelRegistry = ModelRegistry()
# Eager pre-warm on first import: ensures connectors are ready in this process.
# Critical for chatbot performance — avoids 48 s latency on first request.
# Runs when this module is first imported (lifespan or first chatbot request).
def _eager_prewarm() -> None:
try:
modelRegistry.ensureConnectorsRegistered()
modelRegistry.refreshModels(force=True)
logger.info("AI connectors and model registry pre-warmed (module load)")
except Exception as e:
logger.warning(f"AI eager pre-warm skipped: {e}")
_eager_prewarm()
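# Illustrative smoke test, runnable as a module so the relative imports
# resolve (python -m modules.aicore.aicoreModelRegistry). Assumes the
# discovered connectors can initialize in this environment; any credential
# or configuration requirements are connector-specific.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    _stats = modelRegistry.getModelStats()
    print(f"Connectors: {_stats['connectors']}, models: {_stats['totalModels']} ({_stats['availableModels']} available)")
    for _model in modelRegistry.getAvailableModels():
        print(f"  {_model.displayName} [{_model.connectorType}] maxTokens={_model.maxTokens}")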