# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Dynamic model registry that collects models from all AI connectors.
Implements plugin-like architecture for connector discovery.
"""
import logging
import importlib
import os
import time
import threading
from collections import Counter
from typing import Dict, List, Optional, Any, Tuple

from modules.datamodels.datamodelAi import AiModel
from .aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelUam import User
from modules.security.rbacHelpers import checkResourceAccess
from modules.security.rbac import RbacClass
# NOTE(review): not referenced in this module; kept in case other modules import it from here.
from modules.connectors.connectorDbPostgre import DatabaseConnector

logger = logging.getLogger(__name__)

# TODO TESTING: Override maxTokens for all models during testing
# Set to None to disable override, or set to an integer (e.g., 20000) to override all models
TESTING_MAX_TOKENS_OVERRIDE: Optional[int] = None  # TODO TESTING: Set to None to disable


class ModelRegistry:
    """Dynamic registry for AI models from all connectors.

    Models are keyed by their ``displayName``, which must be unique across all
    connectors. Connectors are discovered once from ``aicorePlugin*.py`` files
    in this package, and the model set is re-collected from them at most every
    ``_refreshInterval`` seconds (unless a refresh is forced).
    """

    def __init__(self):
        self._models: Dict[str, AiModel] = {}
        self._connectors: Dict[str, BaseConnectorAi] = {}
        self._lastRefresh: Optional[float] = None
        self._refreshInterval: float = 300.0  # 5 minutes
        self._refreshLock = threading.Lock()
        self._connectorsInitialized: bool = False
        # Avoid re-instantiating on every discoverConnectors() call
        self._discoveredConnectorsCache: Optional[List[BaseConnectorAi]] = None
        # (userId, id(rbacInstance), mandateId, featureInstanceId) -> (models, timestamp)
        self._getAvailableModelsCache: Dict[Tuple[str, int, str, str], Tuple[List[AiModel], float]] = {}
        self._getAvailableModelsCacheTtl: float = 30.0  # seconds

    def _addModelToDict(self, model: AiModel, connectorType: str, target: Dict[str, AiModel]):
        """Add model to a dict, tolerating benign re-adds from the same connector.

        Args:
            model: Model to add; keyed by ``model.displayName``.
            connectorType: Type of the connector the model comes from (used for logging/errors).
            target: Dict that receives the model.

        Raises:
            ValueError: If ``displayName`` is already taken by a model with a different
                ``name`` or ``connectorType``.

        When ``TESTING_MAX_TOKENS_OVERRIDE`` is set, ``model.maxTokens`` is capped in
        place, i.e. the passed model object itself is mutated.
        """
        if model.displayName in target:
            existing = target[model.displayName]
            if existing.name == model.name and existing.connectorType == model.connectorType:
                logger.debug(f"Skipping duplicate model '{model.displayName}' from same connector {connectorType}")
                return
            raise ValueError(
                f"displayName conflict '{model.displayName}': "
                f"existing name='{existing.name}' (connector: {existing.connectorType}), "
                f"new name='{model.name}' (connector: {connectorType})"
            )

        if TESTING_MAX_TOKENS_OVERRIDE is not None and model.maxTokens > TESTING_MAX_TOKENS_OVERRIDE:
            originalMaxTokens = model.maxTokens
            model.maxTokens = TESTING_MAX_TOKENS_OVERRIDE
            logger.debug(f"TESTING: Overrode maxTokens for {model.displayName}: {originalMaxTokens} -> {TESTING_MAX_TOKENS_OVERRIDE}")

        target[model.displayName] = model
        logger.debug(f"Registered model: {model.displayName} (name: {model.name}) from {connectorType}")

    def _addModel(self, model: AiModel, connectorType: str):
        """Convenience wrapper for adding to self._models."""
        self._addModelToDict(model, connectorType, self._models)

    def registerConnector(self, connector: BaseConnectorAi):
        """Register a connector and collect its models.

        A connector type that is already registered is skipped so its models are
        not added twice.

        Raises:
            Exception: Whatever ``connector.getCachedModels()`` or the model
                registration raises (e.g. ``ValueError`` on a displayName conflict);
                it is logged and re-raised.
        """
        connectorType = connector.getConnectorType()

        # If connector already registered, skip re-registration to avoid duplicate models
        if connectorType in self._connectors:
            logger.debug(f"Connector {connectorType} already registered, skipping re-registration")
            return

        self._connectors[connectorType] = connector
        try:
            for model in connector.getCachedModels():
                self._addModel(model, connectorType)
        except Exception as e:
            logger.error(f"Failed to register models from {connectorType}: {e}")
            raise
        finally:
            # The model set changed (possibly only partially); drop filtered views built from the old set.
            self._getAvailableModelsCache.clear()

    def discoverConnectors(self) -> List[BaseConnectorAi]:
        """Auto-discover connectors by scanning aicorePlugin*.py files.

        Cached after first call to avoid 4-8 s re-init on every use.

        Every class in a plugin module that subclasses ``BaseConnectorAi`` (other
        than the base itself) is instantiated with no arguments. Import or
        instantiation failures are logged as warnings and skipped.

        Returns:
            The list of instantiated connectors (the same cached list on later calls).
        """
        if self._discoveredConnectorsCache is not None:
            return self._discoveredConnectorsCache

        connectors: List[BaseConnectorAi] = []
        connectorDir = os.path.dirname(__file__)

        # os.listdir order is arbitrary; sort so discovery/registration order is deterministic.
        for filename in sorted(os.listdir(connectorDir)):
            if not (filename.startswith('aicorePlugin') and filename.endswith('.py')):
                continue
            moduleName = filename[:-3]  # Remove .py extension
            try:
                module = importlib.import_module(f'modules.aicore.{moduleName}')
            except Exception as e:
                logger.warning(f"Failed to discover connector from {filename}: {e}")
                continue

            # Find connector classes (classes that inherit from BaseConnectorAi)
            for attrName in dir(module):
                attr = getattr(module, attrName)
                if not (isinstance(attr, type)
                        and issubclass(attr, BaseConnectorAi)
                        and attr is not BaseConnectorAi):
                    continue
                # Isolate each class so one broken connector does not hide the others in this module,
                # and only keep connectors that can report their type (registerConnector needs it).
                try:
                    connector = attr()
                    connectorType = connector.getConnectorType()
                except Exception as e:
                    logger.warning(f"Failed to discover connector from {filename}: {e}")
                    continue
                connectors.append(connector)
                logger.info(f"Discovered connector: {connectorType}")

        self._discoveredConnectorsCache = connectors
        return connectors

    def ensureConnectorsRegistered(self):
        """Register connectors once to avoid per-request discovery."""
        if self._connectorsInitialized:
            return
        for connector in self.discoverConnectors():
            self.registerConnector(connector)
        self._connectorsInitialized = True

    def _isRefreshDue(self) -> bool:
        """Return True if no refresh has happened yet or the refresh interval has elapsed."""
        return (self._lastRefresh is None
                or time.time() - self._lastRefresh >= self._refreshInterval)

    def refreshModels(self, force: bool = False):
        """Refresh models from all registered connectors.

        Thread-safe via _refreshLock. If another thread is already refreshing,
        this call returns immediately without waiting.

        Args:
            force: Refresh even if the refresh interval has not elapsed yet.

        Raises:
            Exception: Re-raised from a connector that fails to deliver models; in
                that case the previous model set is kept.
        """
        self.ensureConnectorsRegistered()

        if not force and not self._isRefreshDue():
            return

        if not self._refreshLock.acquire(blocking=False):
            logger.debug("refreshModels already running in another thread, skipping")
            return
        try:
            # Re-check under the lock: another thread may have completed a refresh in between.
            if not force and not self._isRefreshDue():
                return

            logger.info("Refreshing model registry...")
            # Build into a fresh dict and swap at the end so readers never see a half-built set.
            newModels: Dict[str, AiModel] = {}
            for connector in self._connectors.values():
                connectorType = connector.getConnectorType()
                try:
                    connector.clearCache()
                    for model in connector.getCachedModels():
                        self._addModelToDict(model, connectorType, newModels)
                except Exception as e:
                    logger.error(f"Failed to refresh models from {connectorType}: {e}")
                    raise
            self._models = newModels
            self._lastRefresh = time.time()
            # Cached availability lists reference the replaced model objects.
            self._getAvailableModelsCache.clear()
            logger.info(f"Model registry refreshed: {len(self._models)} models available")
        finally:
            self._refreshLock.release()

    def getModels(self) -> List[AiModel]:
        """Get all available models."""
        self.refreshModels()
        return list(self._models.values())

    def getModelsByConnector(self, connectorType: str) -> List[AiModel]:
        """Get models from a specific connector."""
        self.refreshModels()
        return [model for model in self._models.values() if model.connectorType == connectorType]

    def getModelsByPriority(self, priority: str) -> List[AiModel]:
        """Get models that have a specific priority."""
        self.refreshModels()
        return [model for model in self._models.values() if model.priority == priority]

    def getAvailableModels(
        self,
        currentUser: Optional[User] = None,
        rbacInstance: Optional[RbacClass] = None,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ) -> List[AiModel]:
        """Get only available models, optionally filtered by RBAC permissions.

        Results are cached for 30 s per (user, rbac instance, mandate, feature
        instance) to avoid repeated filtering on each LLM call. The mandate and
        feature instance are part of the key because RBAC filtering depends on
        them. The cache is cleared whenever the model set is refreshed.

        Args:
            currentUser: Optional user object for RBAC filtering
            rbacInstance: Optional RBAC instance for permission checks
            mandateId: Optional mandate context for faster RBAC (loads fewer roles)
            featureInstanceId: Optional feature instance for RBAC context

        Returns:
            A new list of available models (filtered by RBAC if user and RBAC
            instance are provided); callers may mutate it freely.
        """
        self.refreshModels()

        userKey = str(currentUser.id) if currentUser else ""
        # NOTE(review): id(rbacInstance) can be reused after that instance is garbage-collected;
        # the short TTL bounds the impact.
        cacheKey = (
            userKey,
            id(rbacInstance) if rbacInstance else 0,
            mandateId or "",
            featureInstanceId or "",
        )
        now = time.time()
        cached = self._getAvailableModelsCache.get(cacheKey)
        if cached is not None:
            cachedModels, cachedTs = cached
            if now - cachedTs < self._getAvailableModelsCacheTtl:
                logger.debug(f"getAvailableModels: cache hit for user={userKey[:8] if userKey else 'anon'}...")
                return list(cachedModels)

        allModels = list(self._models.values())
        availableModels = [model for model in allModels if model.isAvailable]
        unavailableModels = [model.name for model in allModels if not model.isAvailable]

        # Apply RBAC filtering if user and RBAC instance provided (batch check for performance)
        if currentUser and rbacInstance:
            availableModels = self._filterModelsByRbac(
                availableModels, currentUser, rbacInstance, mandateId, featureInstanceId
            )

        # Store a private copy so callers mutating the returned list cannot corrupt the cache.
        self._getAvailableModelsCache[cacheKey] = (list(availableModels), now)
        self._pruneAvailableModelsCache(now)

        if unavailableModels:
            logger.debug(f"getAvailableModels: {len(availableModels)} available, {len(unavailableModels)} unavailable. Unavailable: {unavailableModels}")

        logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
        return availableModels

    def _pruneAvailableModelsCache(self, now: float) -> None:
        """Drop expired getAvailableModels() entries so the cache cannot grow unbounded."""
        # Iterate over a snapshot and use pop(): other threads may insert or prune concurrently,
        # and iterating the live dict could then raise "dictionary changed size during iteration".
        for key, (_, ts) in list(self._getAvailableModelsCache.items()):
            if now - ts >= self._getAvailableModelsCacheTtl:
                self._getAvailableModelsCache.pop(key, None)

    def _filterModelsByRbac(
        self,
        models: List[AiModel],
        currentUser: User,
        rbacInstance: RbacClass,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ) -> List[AiModel]:
        """Filter models based on RBAC permissions. Uses bulk check for performance.

        A model is kept if the user has access to either its connector resource
        (``ai.model.<connectorType>``) or the model resource
        (``ai.model.<connectorType>.<displayName>``).
        """
        # (connectorPath, modelPath) per model, aligned with `models` by position.
        pathPairs = [
            (f"ai.model.{model.connectorType}", f"ai.model.{model.connectorType}.{model.displayName}")
            for model in models
        ]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        uniquePaths = list(dict.fromkeys(path for pair in pathPairs for path in pair))

        # Single bulk RBAC call instead of 2*N per-model calls
        access = rbacInstance.checkResourceAccessBulk(
            currentUser, uniquePaths, mandateId, featureInstanceId
        )

        filteredModels = []
        for model, (connectorPath, modelPath) in zip(models, pathPairs):
            if access.get(connectorPath, False) or access.get(modelPath, False):
                filteredModels.append(model)
            else:
                logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")
        return filteredModels

    def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]:
        """Get a specific model by displayName, optionally checking RBAC permissions.

        Args:
            displayName: Model display name (unique across connectors)
            currentUser: Optional user object for RBAC check
            rbacInstance: Optional RBAC instance for permission check

        Returns:
            Model if found and user has access (or if no user provided), None otherwise
        """
        self.refreshModels()
        model = self._models.get(displayName)
        if not model:
            return None

        # Check RBAC permission if user provided
        if currentUser and rbacInstance:
            connectorResourcePath = f"ai.model.{model.connectorType}"
            modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"
            # Connector-level access is sufficient; the model-level check is only needed without it.
            hasAccess = (
                checkResourceAccess(rbacInstance, currentUser, connectorResourcePath)
                or checkResourceAccess(rbacInstance, currentUser, modelResourcePath)
            )
            if not hasAccess:
                logger.warning(f"User {currentUser.username} does not have access to model {displayName}")
                return None

        return model

    def getConnectorForModel(self, displayName: str) -> Optional[BaseConnectorAi]:
        """Get the connector instance for a specific model by displayName."""
        model = self.getModel(displayName)
        if model:
            return self._connectors.get(model.connectorType)
        return None

    def getModelStats(self) -> Dict[str, Any]:
        """Get statistics about the model registry.

        Returns:
            Dict with ``totalModels``, ``availableModels``, ``connectors`` counts and
            per-connector / per-capability / per-priority model counts.
        """
        self.refreshModels()
        # Snapshot once so all counts describe the same model set even if a refresh swaps self._models.
        models = list(self._models.values())
        return {
            "totalModels": len(models),
            "availableModels": sum(1 for m in models if m.isAvailable),
            "connectors": len(self._connectors),
            "byConnector": dict(Counter(m.connectorType for m in models)),
            "byCapability": dict(Counter(capability for m in models for capability in m.capabilities)),
            "byPriority": dict(Counter(m.priority for m in models)),
        }


# Global registry instance
modelRegistry = ModelRegistry()


# Eager pre-warm on first import: ensures connectors are ready in this process.
# Critical for chatbot performance — avoids 4–8 s latency on first request.
# Runs when this module is first imported (lifespan or first chatbot request).
def _eager_prewarm() -> None:
    """Register connectors and force a model refresh; failures are logged, not raised."""
    try:
        modelRegistry.ensureConnectorsRegistered()
        modelRegistry.refreshModels(force=True)
        logger.info("AI connectors and model registry pre-warmed (module load)")
    except Exception as e:
        logger.warning(f"AI eager pre-warm skipped: {e}")


_eager_prewarm()