# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Dynamic model registry that collects models from all AI connectors.

Implements a plugin-like architecture: connector modules named
``aicorePlugin*.py`` in this package are discovered at runtime,
instantiated once, and their models aggregated into a single registry
keyed by unique ``displayName``.
"""
import logging
import importlib
import os
import time
from typing import Dict, List, Optional, Any, Tuple
from modules.datamodels.datamodelAi import AiModel
from .aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelUam import User
from modules.security.rbacHelpers import checkResourceAccess
from modules.security.rbac import RbacClass
from modules.connectors.connectorDbPostgre import DatabaseConnector

logger = logging.getLogger(__name__)

# TODO TESTING: Override maxTokens for all models during testing.
# Set to None to disable the override, or to an integer (e.g., 20000) to cap
# maxTokens on every registered model.
TESTING_MAX_TOKENS_OVERRIDE: Optional[int] = None  # TODO TESTING: Set to None to disable


class ModelRegistry:
    """Dynamic registry for AI models from all connectors.

    Models are keyed by their ``displayName``, which must be globally unique
    across all connectors; duplicates raise ``ValueError`` at registration.
    """

    def __init__(self):
        # displayName -> model
        self._models: Dict[str, AiModel] = {}
        # connectorType -> connector instance
        self._connectors: Dict[str, BaseConnectorAi] = {}
        self._lastRefresh: Optional[float] = None
        self._refreshInterval: float = 300.0  # 5 minutes
        self._connectorsInitialized: bool = False
        # Avoid re-instantiating connectors on every discoverConnectors() call
        self._discoveredConnectorsCache: Optional[List[BaseConnectorAi]] = None
        # (user_id, rbac_id) -> (models, ts); short-lived per-user result cache
        self._getAvailableModelsCache: Dict[Tuple[str, int], Tuple[List[AiModel], float]] = {}
        self._getAvailableModelsCacheTtl: float = 30.0  # seconds

    def _addModel(self, model: AiModel, connectorType: str):
        """Validate and store a single model under its unique displayName.

        Shared by registerConnector() and refreshModels() (previously
        copy-pasted in both).

        Raises:
            ValueError: If another model already uses the same displayName.
        """
        # Validate displayName uniqueness
        if model.displayName in self._models:
            existingModel = self._models[model.displayName]
            errorMsg = f"Duplicate displayName '{model.displayName}' detected! Existing model: displayName='{existingModel.displayName}', name='{existingModel.name}' (connector: {existingModel.connectorType}), New model: displayName='{model.displayName}', name='{model.name}' (connector: {connectorType}). displayName must be unique."
            logger.error(errorMsg)
            raise ValueError(errorMsg)
        # TODO TESTING: Override maxTokens if testing override is enabled
        if TESTING_MAX_TOKENS_OVERRIDE is not None and model.maxTokens > TESTING_MAX_TOKENS_OVERRIDE:
            originalMaxTokens = model.maxTokens
            model.maxTokens = TESTING_MAX_TOKENS_OVERRIDE
            logger.debug(f"TESTING: Overrode maxTokens for {model.displayName}: {originalMaxTokens} -> {TESTING_MAX_TOKENS_OVERRIDE}")
        # Use displayName as the key (must be unique)
        self._models[model.displayName] = model
        logger.debug(f"Registered model: {model.displayName} (name: {model.name}) from {connectorType}")

    def registerConnector(self, connector: BaseConnectorAi):
        """Register a connector and collect its models.

        Raises:
            ValueError: On a duplicate model displayName.
        """
        connectorType = connector.getConnectorType()
        # If connector already registered, skip re-registration to avoid duplicate models
        if connectorType in self._connectors:
            logger.debug(f"Connector {connectorType} already registered, skipping re-registration")
            return
        self._connectors[connectorType] = connector
        # Collect models from this connector
        try:
            for model in connector.getCachedModels():
                self._addModel(model, connectorType)
        except Exception as e:
            logger.error(f"Failed to register models from {connectorType}: {e}")
            raise

    def discoverConnectors(self) -> List[BaseConnectorAi]:
        """Auto-discover connectors by scanning aicorePlugin*.py files.

        Cached after the first call to avoid 4-8 s re-init on every use.
        """
        if self._discoveredConnectorsCache is not None:
            return self._discoveredConnectorsCache
        connectors = []
        connectorDir = os.path.dirname(__file__)
        # Scan for connector files
        for filename in os.listdir(connectorDir):
            if filename.startswith('aicorePlugin') and filename.endswith('.py'):
                moduleName = filename[:-3]  # Remove .py extension
                try:
                    module = importlib.import_module(f'modules.aicore.{moduleName}')
                    # Find connector classes (classes that inherit from BaseConnectorAi)
                    for attrName in dir(module):
                        attr = getattr(module, attrName)
                        if (isinstance(attr, type)
                                and issubclass(attr, BaseConnectorAi)
                                and attr != BaseConnectorAi):
                            # Instantiate the connector
                            connector = attr()
                            connectors.append(connector)
                            logger.info(f"Discovered connector: {connector.getConnectorType()}")
                except Exception as e:
                    # BUGFIX: previously logged a "(unknown)" placeholder; name the
                    # module that actually failed so failures are diagnosable.
                    logger.warning(f"Failed to discover connector from {moduleName}: {e}")
        self._discoveredConnectorsCache = connectors
        return connectors

    def ensureConnectorsRegistered(self):
        """Register connectors once to avoid per-request discovery."""
        if self._connectorsInitialized:
            return
        for connector in self.discoverConnectors():
            self.registerConnector(connector)
        self._connectorsInitialized = True

    def refreshModels(self, force: bool = False):
        """Refresh models from all registered connectors.

        Skipped when the last refresh was less than ``self._refreshInterval``
        seconds ago, unless ``force`` is True.

        Raises:
            ValueError: On a duplicate model displayName.
        """
        self.ensureConnectorsRegistered()
        currentTime = time.time()
        # Check if refresh is needed
        if (not force and self._lastRefresh is not None
                and currentTime - self._lastRefresh < self._refreshInterval):
            return
        logger.info("Refreshing model registry...")
        # Clear existing models, then re-register all connectors
        self._models.clear()
        for connector in self._connectors.values():
            try:
                connector.clearCache()  # Clear connector cache
                for model in connector.getCachedModels():
                    self._addModel(model, connector.getConnectorType())
            except Exception as e:
                logger.error(f"Failed to refresh models from {connector.getConnectorType()}: {e}")
                raise
        self._lastRefresh = currentTime
        logger.info(f"Model registry refreshed: {len(self._models)} models available")

    # NOTE(review): the original file defined getModel() twice; the first,
    # RBAC-less definition was dead code silently shadowed by the RBAC-aware
    # one below (which is a strict superset — its RBAC args default to None).
    # The dead definition has been removed.

    def getModels(self) -> List[AiModel]:
        """Get all available models."""
        self.refreshModels()
        return list(self._models.values())

    def getModelsByConnector(self, connectorType: str) -> List[AiModel]:
        """Get models from a specific connector."""
        self.refreshModels()
        return [model for model in self._models.values() if model.connectorType == connectorType]

    def getModelsByPriority(self, priority: str) -> List[AiModel]:
        """Get models that have a specific priority."""
        self.refreshModels()
        return [model for model in self._models.values() if model.priority == priority]

    def getAvailableModels(
        self,
        currentUser: Optional[User] = None,
        rbacInstance: Optional[RbacClass] = None,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ) -> List[AiModel]:
        """Get only available models, optionally filtered by RBAC permissions.

        Results are cached per (user, rbac) for 30s to avoid repeated
        filtering on each LLM call.

        Args:
            currentUser: Optional user object for RBAC filtering
            rbacInstance: Optional RBAC instance for permission checks
            mandateId: Optional mandate context for faster RBAC (loads fewer roles)
            featureInstanceId: Optional feature instance for RBAC context

        Returns:
            List of available models (filtered by RBAC if user provided)
        """
        self.refreshModels()
        cache_key = (currentUser.id if currentUser else "", id(rbacInstance) if rbacInstance else 0)
        now = time.time()
        if cache_key in self._getAvailableModelsCache:
            cached_models, cached_ts = self._getAvailableModelsCache[cache_key]
            if now - cached_ts < self._getAvailableModelsCacheTtl:
                logger.debug(f"getAvailableModels: cache hit for user={cache_key[0][:8] if cache_key[0] else 'anon'}...")
                return cached_models
        allModels = list(self._models.values())
        availableModels = [model for model in allModels if model.isAvailable]
        # Apply RBAC filtering if user and RBAC instance provided (batch check for performance)
        if currentUser and rbacInstance:
            availableModels = self._filterModelsByRbac(
                availableModels, currentUser, rbacInstance, mandateId, featureInstanceId
            )
        self._getAvailableModelsCache[cache_key] = (availableModels, now)
        # Prune expired entries to avoid unbounded growth
        expired = [k for k, (_, ts) in self._getAvailableModelsCache.items()
                   if now - ts >= self._getAvailableModelsCacheTtl]
        for k in expired:
            del self._getAvailableModelsCache[k]
        # NOTE: unavailableCount also includes models removed by RBAC filtering,
        # not only models with isAvailable == False (debug-logging only).
        unavailableCount = len(allModels) - len(availableModels)
        if unavailableCount > 0:
            unavailableModels = [m.name for m in allModels if not m.isAvailable]
            logger.debug(f"getAvailableModels: {len(availableModels)} available, {unavailableCount} unavailable. Unavailable: {unavailableModels}")
        logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
        return availableModels

    def _filterModelsByRbac(
        self,
        models: List[AiModel],
        currentUser: User,
        rbacInstance: RbacClass,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ) -> List[AiModel]:
        """Filter models based on RBAC permissions. Uses bulk check for performance.

        A model is kept when the user has access to either the connector-level
        path (``ai.model.<connector>``) or the model-level path
        (``ai.model.<connector>.<displayName>``).
        """
        paths = []
        model_paths = {}  # id(model) -> (connector_path, model_path)
        for model in models:
            connector_path = f"ai.model.{model.connectorType}"
            model_path = f"ai.model.{model.connectorType}.{model.displayName}"
            paths.extend([connector_path, model_path])
            model_paths[id(model)] = (connector_path, model_path)
        # Single bulk RBAC call instead of 2*N per-model calls
        access = rbacInstance.checkResourceAccessBulk(
            currentUser, list(dict.fromkeys(paths)), mandateId, featureInstanceId
        )
        filteredModels = []
        for model in models:
            connector_path, model_path = model_paths[id(model)]
            if access.get(connector_path, False) or access.get(model_path, False):
                filteredModels.append(model)
            else:
                logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")
        return filteredModels

    def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]:
        """Get a specific model by displayName, optionally checking RBAC permissions.

        Args:
            displayName: Model display name
            currentUser: Optional user object for RBAC check
            rbacInstance: Optional RBAC instance for permission check

        Returns:
            Model if found and user has access (or if no user provided), None otherwise
        """
        self.refreshModels()
        model = self._models.get(displayName)
        if not model:
            return None
        # Check RBAC permission if user provided
        if currentUser and rbacInstance:
            connectorResourcePath = f"ai.model.{model.connectorType}"
            modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"
            hasConnectorAccess = checkResourceAccess(rbacInstance, currentUser, connectorResourcePath)
            hasModelAccess = checkResourceAccess(rbacInstance, currentUser, modelResourcePath)
            if not (hasConnectorAccess or hasModelAccess):
                logger.warning(f"User {currentUser.username} does not have access to model {displayName}")
                return None
        return model

    def getConnectorForModel(self, displayName: str) -> Optional[BaseConnectorAi]:
        """Get the connector instance for a specific model by displayName."""
        model = self.getModel(displayName)
        if model:
            return self._connectors.get(model.connectorType)
        return None

    def getModelStats(self) -> Dict[str, Any]:
        """Get statistics about the model registry.

        Returns:
            Dict with total/available counts plus per-connector,
            per-capability and per-priority tallies.
        """
        self.refreshModels()
        stats = {
            "totalModels": len(self._models),
            "availableModels": len([m for m in self._models.values() if m.isAvailable]),
            "connectors": len(self._connectors),
            "byConnector": {},
            "byCapability": {},
            "byPriority": {}
        }
        # Single pass over models for all three tallies (was three passes)
        for model in self._models.values():
            stats["byConnector"][model.connectorType] = stats["byConnector"].get(model.connectorType, 0) + 1
            for capability in model.capabilities:
                stats["byCapability"][capability] = stats["byCapability"].get(capability, 0) + 1
            stats["byPriority"][model.priority] = stats["byPriority"].get(model.priority, 0) + 1
        return stats


# Global registry instance
modelRegistry = ModelRegistry()


# Eager pre-warm on first import: ensures connectors are ready in this process.
# Critical for chatbot performance — avoids 4–8 s latency on first request.
# Runs when this module is first imported (lifespan or first chatbot request).
def _eager_prewarm() -> None:
    try:
        modelRegistry.ensureConnectorsRegistered()
        modelRegistry.refreshModels(force=True)
        logger.info("AI connectors and model registry pre-warmed (module load)")
    except Exception as e:
        logger.warning(f"AI eager pre-warm skipped: {e}")


_eager_prewarm()