Change summary:
- app.py: pre-warm AI connectors at module load and in lifespan
- aicoreModelRegistry.py: connector discovery cache, getAvailableModels cache, bulk RBAC checks, eager pre-warm
- connectorDbPostgre.py: connector cache, contextvars for userId, eviction (max 32)
- chatbot: uses _get_cached_connector; service-center integration; BillingService exceptions instead of direct imports
- interfaceDbApp.py: uses _get_cached_connector
- interfaceDbManagement.py: uses _get_cached_connector
- security/rbac.py: adds checkResourceAccessBulk

355 lines · 16 KiB · Python
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Dynamic model registry that collects models from all AI connectors.
Implements plugin-like architecture for connector discovery.
"""

import logging
import importlib
import os
import time
from typing import Dict, List, Optional, Any, Tuple
from modules.datamodels.datamodelAi import AiModel
from .aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelUam import User
from modules.security.rbacHelpers import checkResourceAccess
from modules.security.rbac import RbacClass
from modules.connectors.connectorDbPostgre import DatabaseConnector

logger = logging.getLogger(__name__)
# TODO TESTING: global maxTokens cap applied to every registered model while testing.
# None disables the override; an integer (e.g. 20000) caps maxTokens for all models.
TESTING_MAX_TOKENS_OVERRIDE: Optional[int] = None  # TODO TESTING: set to None to disable


class ModelRegistry:
    """Dynamic registry for AI models from all connectors.

    Discovers connector plugins (aicorePlugin*.py), registers their models
    keyed by unique displayName, and serves model lookups with short-lived
    caching (connector discovery cache, RBAC-filtered availability cache)
    to keep per-request latency low.
    """

    def __init__(self):
        # displayName -> model (displayName is enforced unique across connectors)
        self._models: Dict[str, AiModel] = {}
        # connectorType -> connector instance
        self._connectors: Dict[str, BaseConnectorAi] = {}
        self._lastRefresh: Optional[float] = None
        self._refreshInterval: float = 300.0  # 5 minutes
        self._connectorsInitialized: bool = False
        # Avoid re-instantiating connectors on every discoverConnectors() call
        self._discoveredConnectorsCache: Optional[List[BaseConnectorAi]] = None
        # (user_id, rbac_id) -> (models, timestamp); avoids re-filtering on each LLM call
        self._getAvailableModelsCache: Dict[Tuple[str, int], Tuple[List[AiModel], float]] = {}
        self._getAvailableModelsCacheTtl: float = 30.0  # seconds

    def _collectModelsFrom(self, connector: BaseConnectorAi, logEachModel: bool = False) -> None:
        """Collect a connector's cached models into the registry.

        Shared by registerConnector() and refreshModels(), which previously
        duplicated this validation/override loop verbatim.

        Args:
            connector: Connector whose cached models are collected.
            logEachModel: Emit a per-model debug log (registration path only).

        Raises:
            ValueError: If a model's displayName collides with an existing entry.
        """
        connectorType = connector.getConnectorType()
        for model in connector.getCachedModels():
            # displayName is the registry key, so it must be globally unique
            if model.displayName in self._models:
                existingModel = self._models[model.displayName]
                errorMsg = (
                    f"Duplicate displayName '{model.displayName}' detected! "
                    f"Existing model: displayName='{existingModel.displayName}', "
                    f"name='{existingModel.name}' (connector: {existingModel.connectorType}), "
                    f"New model: displayName='{model.displayName}', name='{model.name}' "
                    f"(connector: {connectorType}). displayName must be unique."
                )
                logger.error(errorMsg)
                raise ValueError(errorMsg)

            # TODO TESTING: Override maxTokens if testing override is enabled
            if TESTING_MAX_TOKENS_OVERRIDE is not None and model.maxTokens > TESTING_MAX_TOKENS_OVERRIDE:
                originalMaxTokens = model.maxTokens
                model.maxTokens = TESTING_MAX_TOKENS_OVERRIDE
                logger.debug(f"TESTING: Overrode maxTokens for {model.displayName}: {originalMaxTokens} -> {TESTING_MAX_TOKENS_OVERRIDE}")

            # Use displayName as the key (must be unique)
            self._models[model.displayName] = model
            if logEachModel:
                logger.debug(f"Registered model: {model.displayName} (name: {model.name}) from {connectorType}")

    def registerConnector(self, connector: BaseConnectorAi):
        """Register a connector and collect its models.

        Raises:
            ValueError: On duplicate model displayName.
        """
        connectorType = connector.getConnectorType()

        # If connector already registered, skip re-registration to avoid duplicate models
        if connectorType in self._connectors:
            logger.debug(f"Connector {connectorType} already registered, skipping re-registration")
            return

        self._connectors[connectorType] = connector

        try:
            self._collectModelsFrom(connector, logEachModel=True)
        except Exception as e:
            logger.error(f"Failed to register models from {connectorType}: {e}")
            raise

    def discoverConnectors(self) -> List[BaseConnectorAi]:
        """Auto-discover connectors by scanning aicorePlugin*.py files.

        Cached after first call to avoid 4-8 s re-init on every use.
        """
        if self._discoveredConnectorsCache is not None:
            return self._discoveredConnectorsCache

        connectors = []
        connectorDir = os.path.dirname(__file__)

        # Scan for connector files
        for filename in os.listdir(connectorDir):
            if filename.startswith('aicorePlugin') and filename.endswith('.py'):
                moduleName = filename[:-3]  # Remove .py extension
                try:
                    # Import the module
                    module = importlib.import_module(f'modules.aicore.{moduleName}')

                    # Find connector classes (classes that inherit from BaseConnectorAi)
                    for attrName in dir(module):
                        attr = getattr(module, attrName)
                        if (isinstance(attr, type) and
                                issubclass(attr, BaseConnectorAi) and
                                attr != BaseConnectorAi):
                            # Instantiate the connector
                            connector = attr()
                            connectors.append(connector)
                            logger.info(f"Discovered connector: {connector.getConnectorType()}")
                except Exception as e:
                    # BUGFIX: previously logged the literal placeholder "(unknown)"
                    # instead of the plugin file that failed to load.
                    logger.warning(f"Failed to discover connector from {filename}: {e}")

        self._discoveredConnectorsCache = connectors
        return connectors

    def ensureConnectorsRegistered(self):
        """Register connectors once to avoid per-request discovery."""
        if self._connectorsInitialized:
            return
        for connector in self.discoverConnectors():
            self.registerConnector(connector)
        self._connectorsInitialized = True

    def refreshModels(self, force: bool = False):
        """Refresh models from all registered connectors.

        No-op while the last refresh is younger than the refresh interval,
        unless force=True.
        """
        self.ensureConnectorsRegistered()

        currentTime = time.time()

        # Check if refresh is needed
        if (not force and
                self._lastRefresh is not None and
                currentTime - self._lastRefresh < self._refreshInterval):
            return

        logger.info("Refreshing model registry...")

        # Clear existing models, then re-collect from all connectors
        self._models.clear()
        for connector in self._connectors.values():
            try:
                connector.clearCache()  # Clear connector cache so models are re-fetched
                self._collectModelsFrom(connector)
            except Exception as e:
                logger.error(f"Failed to refresh models from {connector.getConnectorType()}: {e}")
                raise

        self._lastRefresh = currentTime
        logger.info(f"Model registry refreshed: {len(self._models)} models available")

    def getModels(self) -> List[AiModel]:
        """Get all available models."""
        self.refreshModels()
        return list(self._models.values())

    def getModelsByConnector(self, connectorType: str) -> List[AiModel]:
        """Get models from a specific connector."""
        self.refreshModels()
        return [model for model in self._models.values() if model.connectorType == connectorType]

    def getModelsByPriority(self, priority: str) -> List[AiModel]:
        """Get models that have a specific priority."""
        self.refreshModels()
        return [model for model in self._models.values() if model.priority == priority]

    def getAvailableModels(
        self,
        currentUser: Optional[User] = None,
        rbacInstance: Optional[RbacClass] = None,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ) -> List[AiModel]:
        """Get only available models, optionally filtered by RBAC permissions.

        Results are cached per (user, rbac) for 30s to avoid repeated filtering
        on each LLM call.

        Args:
            currentUser: Optional user object for RBAC filtering
            rbacInstance: Optional RBAC instance for permission checks
            mandateId: Optional mandate context for faster RBAC (loads fewer roles)
            featureInstanceId: Optional feature instance for RBAC context

        Returns:
            List of available models (filtered by RBAC if user provided)
        """
        self.refreshModels()
        # NOTE(review): id(rbacInstance) can be reused after GC, so a stale
        # cache hit is theoretically possible within the 30s TTL — confirm acceptable.
        cache_key = (currentUser.id if currentUser else "", id(rbacInstance) if rbacInstance else 0)
        now = time.time()
        if cache_key in self._getAvailableModelsCache:
            cached_models, cached_ts = self._getAvailableModelsCache[cache_key]
            if now - cached_ts < self._getAvailableModelsCacheTtl:
                logger.debug(f"getAvailableModels: cache hit for user={cache_key[0][:8] if cache_key[0] else 'anon'}...")
                return cached_models

        allModels = list(self._models.values())
        availableModels = [model for model in allModels if model.isAvailable]

        # Apply RBAC filtering if user and RBAC instance provided (batch check for performance)
        if currentUser and rbacInstance:
            availableModels = self._filterModelsByRbac(
                availableModels, currentUser, rbacInstance, mandateId, featureInstanceId
            )

        self._getAvailableModelsCache[cache_key] = (availableModels, now)
        # Prune expired entries to avoid unbounded growth
        expired = [k for k, (_, ts) in self._getAvailableModelsCache.items() if now - ts >= self._getAvailableModelsCacheTtl]
        for k in expired:
            del self._getAvailableModelsCache[k]

        unavailableCount = len(allModels) - len(availableModels)
        if unavailableCount > 0:
            unavailableModels = [m.name for m in allModels if not m.isAvailable]
            logger.debug(f"getAvailableModels: {len(availableModels)} available, {unavailableCount} unavailable. Unavailable: {unavailableModels}")
        logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
        return availableModels

    def _filterModelsByRbac(
        self,
        models: List[AiModel],
        currentUser: User,
        rbacInstance: RbacClass,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ) -> List[AiModel]:
        """Filter models based on RBAC permissions. Uses bulk check for performance."""
        paths = []
        model_paths = {}  # id(model) -> (connector_path, model_path)
        for model in models:
            connector_path = f"ai.model.{model.connectorType}"
            model_path = f"ai.model.{model.connectorType}.{model.displayName}"
            paths.extend([connector_path, model_path])
            model_paths[id(model)] = (connector_path, model_path)
        # Single bulk RBAC call instead of 2*N per-model calls
        access = rbacInstance.checkResourceAccessBulk(
            currentUser, list(dict.fromkeys(paths)), mandateId, featureInstanceId
        )
        filteredModels = []
        for model in models:
            connector_path, model_path = model_paths[id(model)]
            # Connector-level grant implies access to all of its models
            if access.get(connector_path, False) or access.get(model_path, False):
                filteredModels.append(model)
            else:
                logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")
        return filteredModels

    def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]:
        """Get a specific model by displayName, optionally checking RBAC permissions.

        BUGFIX: the class previously defined getModel twice; the first
        (RBAC-unaware) definition was dead code shadowed by this one and has
        been removed. This signature is backward compatible with both call forms.

        Args:
            displayName: Model display name
            currentUser: Optional user object for RBAC check
            rbacInstance: Optional RBAC instance for permission check

        Returns:
            Model if found and user has access (or if no user provided), None otherwise
        """
        self.refreshModels()
        model = self._models.get(displayName)

        if not model:
            return None

        # Check RBAC permission if user provided
        if currentUser and rbacInstance:
            connectorResourcePath = f"ai.model.{model.connectorType}"
            modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"

            hasConnectorAccess = checkResourceAccess(rbacInstance, currentUser, connectorResourcePath)
            hasModelAccess = checkResourceAccess(rbacInstance, currentUser, modelResourcePath)

            if not (hasConnectorAccess or hasModelAccess):
                logger.warning(f"User {currentUser.username} does not have access to model {displayName}")
                return None

        return model

    def getConnectorForModel(self, displayName: str) -> Optional[BaseConnectorAi]:
        """Get the connector instance for a specific model by displayName."""
        model = self.getModel(displayName)
        if model:
            return self._connectors.get(model.connectorType)
        return None

    def getModelStats(self) -> Dict[str, Any]:
        """Get statistics about the model registry.

        Returns:
            Dict with total/available/connector counts plus per-connector,
            per-capability and per-priority breakdowns.
        """
        self.refreshModels()

        stats = {
            "totalModels": len(self._models),
            "availableModels": len([m for m in self._models.values() if m.isAvailable]),
            "connectors": len(self._connectors),
            "byConnector": {},
            "byCapability": {},
            "byPriority": {},
        }

        # Single pass over models covers all three breakdowns (previously three loops)
        for model in self._models.values():
            stats["byConnector"][model.connectorType] = stats["byConnector"].get(model.connectorType, 0) + 1
            stats["byPriority"][model.priority] = stats["byPriority"].get(model.priority, 0) + 1
            for capability in model.capabilities:
                stats["byCapability"][capability] = stats["byCapability"].get(capability, 0) + 1

        return stats


# Global registry instance
modelRegistry = ModelRegistry()


def _eager_prewarm() -> None:
    """Pre-warm AI connectors and the model registry at import time.

    Critical for chatbot performance — avoids 4–8 s latency on the first
    request. Runs when this module is first imported (lifespan or first
    chatbot request); any failure is logged as a warning and swallowed so
    importing this module never breaks.
    """
    try:
        modelRegistry.ensureConnectorsRegistered()
        modelRegistry.refreshModels(force=True)
        logger.info("AI connectors and model registry pre-warmed (module load)")
    except Exception as e:
        logger.warning(f"AI eager pre-warm skipped: {e}")


_eager_prewarm()