# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Dynamic model registry that collects models from all AI connectors.
Implements plugin-like architecture for connector discovery.
"""

import logging
import importlib
import os
import time
import threading
from typing import Dict, List, Optional, Any, Tuple, TYPE_CHECKING
from modules.datamodels.datamodelAi import AiModel
from modules.datamodels.datamodelRbac import AccessRuleContext
from .aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelUam import User
from modules.connectors.connectorDbPostgre import DatabaseConnector

if TYPE_CHECKING:
    from modules.security.rbac import RbacClass

logger = logging.getLogger(__name__)

# TODO TESTING: Override maxTokens for all models during testing
# Set to None to disable override, or set to an integer (e.g., 20000) to override all models
TESTING_MAX_TOKENS_OVERRIDE: Optional[int] = None  # TODO TESTING: Set to None to disable


class ModelRegistry:
    """Dynamic registry for AI models from all connectors.

    Models are keyed by their ``displayName``, which must be unique across all
    registered connectors.
    """

    def __init__(self):
        self._models: Dict[str, AiModel] = {}
        self._connectors: Dict[str, BaseConnectorAi] = {}
        self._lastRefresh: Optional[float] = None
        self._refreshInterval: float = 300.0  # 5 minutes
        self._refreshLock = threading.Lock()
        self._connectorsInitialized: bool = False
        # Avoid re-instantiating on every discoverConnectors() call
        self._discoveredConnectorsCache: Optional[List[BaseConnectorAi]] = None
        # (user_id, rbac_id, mandateId, featureInstanceId) -> (models, ts)
        self._getAvailableModelsCache: Dict[
            Tuple[Any, int, Optional[str], Optional[str]],
            Tuple[List[AiModel], float],
        ] = {}
        self._getAvailableModelsCacheTtl: float = 30.0  # seconds

    def _addModelToDict(self, model: AiModel, connectorType: str, target: Dict[str, AiModel]):
        """Add model to a dict, tolerating benign re-adds from the same connector.

        Args:
            model: Model to register.
            connectorType: Type of the connector the model comes from (used in messages).
            target: Dict (displayName -> model) to insert into.

        Raises:
            ValueError: If another model with the same displayName but a different
                name or connector type is already present in ``target``.
        """
        if model.displayName in target:
            existing = target[model.displayName]
            if existing.name == model.name and existing.connectorType == model.connectorType:
                logger.debug(f"Skipping duplicate model '{model.displayName}' from same connector {connectorType}")
                return
            raise ValueError(
                f"displayName conflict '{model.displayName}': "
                f"existing name='{existing.name}' (connector: {existing.connectorType}), "
                f"new name='{model.name}' (connector: {connectorType})"
            )

        # Testing hook: cap maxTokens when the module-level override is set
        if TESTING_MAX_TOKENS_OVERRIDE is not None and model.maxTokens > TESTING_MAX_TOKENS_OVERRIDE:
            originalMaxTokens = model.maxTokens
            model.maxTokens = TESTING_MAX_TOKENS_OVERRIDE
            logger.debug(f"TESTING: Overrode maxTokens for {model.displayName}: {originalMaxTokens} -> {TESTING_MAX_TOKENS_OVERRIDE}")

        target[model.displayName] = model
        logger.debug(f"Registered model: {model.displayName} (name: {model.name}) from {connectorType}")

    def _addModel(self, model: AiModel, connectorType: str):
        """Convenience wrapper for adding to self._models."""
        self._addModelToDict(model, connectorType, self._models)

    def registerConnector(self, connector: BaseConnectorAi):
        """Register a connector and collect its models.

        Raises:
            Exception: Re-raises whatever collecting or adding the connector's
                models raised (after logging it).
        """
        connectorType = connector.getConnectorType()

        # If connector already registered, skip re-registration to avoid duplicate models
        if connectorType in self._connectors:
            logger.debug(f"Connector {connectorType} already registered, skipping re-registration")
            return

        self._connectors[connectorType] = connector

        try:
            models = connector.getCachedModels()
            for model in models:
                self._addModel(model, connectorType)
        except Exception as e:
            logger.error(f"Failed to register models from {connectorType}: {e}")
            raise

    def discoverConnectors(self) -> List[BaseConnectorAi]:
        """Auto-discover connectors by scanning aicorePlugin*.py files.

        Cached after first call to avoid 4-8 s re-init on every use.

        Returns:
            List of instantiated connectors. Modules that fail to import or
            instantiate are logged and skipped.
        """
        if self._discoveredConnectorsCache is not None:
            return self._discoveredConnectorsCache

        connectors = []
        connectorDir = os.path.dirname(__file__)

        # Scan for connector files
        for filename in os.listdir(connectorDir):
            if filename.startswith('aicorePlugin') and filename.endswith('.py'):
                moduleName = filename[:-3]  # Remove .py extension
                try:
                    # Import the module
                    module = importlib.import_module(f'modules.aicore.{moduleName}')

                    # Find connector classes (classes that inherit from BaseConnectorAi)
                    for attrName in dir(module):
                        attr = getattr(module, attrName)
                        if (isinstance(attr, type) and
                                issubclass(attr, BaseConnectorAi) and
                                attr != BaseConnectorAi):
                            # Instantiate the connector
                            connector = attr()
                            connectors.append(connector)
                            logger.info(f"Discovered connector: {connector.getConnectorType()}")
                except Exception as e:
                    logger.warning(f"Failed to discover connector from {filename}: {e}")

        self._discoveredConnectorsCache = connectors
        return connectors

    def ensureConnectorsRegistered(self):
        """Register connectors once to avoid per-request discovery."""
        if self._connectorsInitialized:
            return
        discovered = self.discoverConnectors()
        for connector in discovered:
            self.registerConnector(connector)
        self._connectorsInitialized = True

    def refreshModels(self, force: bool = False):
        """Refresh models from all registered connectors.

        Only one refresh runs at a time: if ``_refreshLock`` is already held the
        call returns immediately without refreshing.

        Args:
            force: When True, refresh even if the refresh interval has not elapsed.

        Raises:
            Exception: Re-raises any connector failure; in that case the
                previously registered models are left untouched.
        """
        self.ensureConnectorsRegistered()
        currentTime = time.time()

        if (not force and
                self._lastRefresh is not None and
                currentTime - self._lastRefresh < self._refreshInterval):
            return

        if not self._refreshLock.acquire(blocking=False):
            logger.debug("refreshModels already running in another thread, skipping")
            return

        try:
            logger.info("Refreshing model registry...")
            # Build into a fresh dict so a failure leaves self._models intact
            newModels: Dict[str, AiModel] = {}

            for connector in self._connectors.values():
                connectorType = connector.getConnectorType()
                try:
                    connector.clearCache()
                    models = connector.getCachedModels()
                    for model in models:
                        self._addModelToDict(model, connectorType, newModels)
                except Exception as e:
                    logger.error(f"Failed to refresh models from {connectorType}: {e}")
                    raise

            self._models = newModels
            # Cached availability lists refer to the previous model set; start over.
            # Rebinding (instead of clear()) keeps readers holding the old dict safe.
            self._getAvailableModelsCache = {}
            self._lastRefresh = time.time()
            logger.info(f"Model registry refreshed: {len(self._models)} models available")
        finally:
            self._refreshLock.release()

    def getModels(self) -> List[AiModel]:
        """Get all available models."""
        self.refreshModels()
        return list(self._models.values())

    def getModelsByConnector(self, connectorType: str) -> List[AiModel]:
        """Get models from a specific connector."""
        self.refreshModels()
        return [model for model in self._models.values() if model.connectorType == connectorType]

    def getModelsByPriority(self, priority: str) -> List[AiModel]:
        """Get models that have a specific priority."""
        self.refreshModels()
        return [model for model in self._models.values() if model.priority == priority]

    def getAvailableModels(
        self,
        currentUser: Optional[User] = None,
        rbacInstance: Optional["RbacClass"] = None,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ) -> List[AiModel]:
        """Get only available models, optionally filtered by RBAC permissions.

        Results are cached per (user, rbac, mandate, feature instance) for 30s to
        avoid repeated filtering on each LLM call. The mandate and feature
        instance are part of the key because they are passed to the RBAC check;
        without them a result computed for one context could be returned for
        another.

        Args:
            currentUser: Optional user object for RBAC filtering
            rbacInstance: Optional RBAC instance for permission checks
            mandateId: Optional mandate context for faster RBAC (loads fewer roles)
            featureInstanceId: Optional feature instance for RBAC context

        Returns:
            List of available models (filtered by RBAC if user provided). The
            list is a fresh copy; mutating it does not affect the cache.
        """
        self.refreshModels()

        applyRbac = bool(currentUser and rbacInstance)
        userKey = currentUser.id if currentUser else ""
        cacheKey = (
            userKey,
            id(rbacInstance) if rbacInstance else 0,
            # Context only influences the result when RBAC filtering runs
            mandateId if applyRbac else None,
            featureInstanceId if applyRbac else None,
        )
        now = time.time()
        # Work on a local reference: refreshModels() may rebind the attribute
        cache = self._getAvailableModelsCache

        cachedEntry = cache.get(cacheKey)
        if cachedEntry is not None:
            cachedModels, cachedTs = cachedEntry
            if now - cachedTs < self._getAvailableModelsCacheTtl:
                logger.debug(f"getAvailableModels: cache hit for user={str(userKey)[:8] if userKey else 'anon'}...")
                return list(cachedModels)

        allModels = list(self._models.values())
        availableModels = [model for model in allModels if model.isAvailable]

        # Apply RBAC filtering if user and RBAC instance provided (batch check for performance)
        if applyRbac:
            availableModels = self._filterModelsByRbac(
                availableModels, currentUser, rbacInstance, mandateId, featureInstanceId
            )

        # Store a copy so later mutation of the returned list cannot corrupt the cache
        cache[cacheKey] = (list(availableModels), now)

        # Prune expired entries to avoid unbounded growth (snapshot + pop so a
        # concurrent insert or removal cannot break the iteration)
        for key, (_, ts) in list(cache.items()):
            if now - ts >= self._getAvailableModelsCacheTtl:
                cache.pop(key, None)

        unavailableCount = len(allModels) - len(availableModels)
        if unavailableCount > 0:
            unavailableModels = [m.name for m in allModels if not m.isAvailable]
            logger.debug(f"getAvailableModels: {len(availableModels)} available, {unavailableCount} unavailable. Unavailable: {unavailableModels}")
        logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
        return availableModels

    def _filterModelsByRbac(
        self,
        models: List[AiModel],
        currentUser: User,
        rbacInstance: "RbacClass",
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ) -> List[AiModel]:
        """Filter models based on RBAC permissions. Uses bulk check for performance.

        A model is kept when the user has access to either the connector-level
        resource path or the model-level resource path.
        """
        paths = []
        modelPaths = {}  # id(model) -> (connectorPath, modelPath)
        for model in models:
            connectorPath = f"ai.model.{model.connectorType}"
            modelPath = f"ai.model.{model.connectorType}.{model.displayName}"
            paths.extend([connectorPath, modelPath])
            modelPaths[id(model)] = (connectorPath, modelPath)

        # Single bulk RBAC call instead of 2*N per-model calls
        # (dict.fromkeys de-duplicates while preserving order)
        access = rbacInstance.checkResourceAccessBulk(
            currentUser, list(dict.fromkeys(paths)), mandateId, featureInstanceId
        )

        filteredModels = []
        for model in models:
            connectorPath, modelPath = modelPaths[id(model)]
            if access.get(connectorPath, False) or access.get(modelPath, False):
                filteredModels.append(model)
            else:
                logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")
        return filteredModels

    def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional["RbacClass"] = None) -> Optional[AiModel]:
        """Get a specific model by displayName, optionally checking RBAC permissions.

        Args:
            displayName: Model display name
            currentUser: Optional user object for RBAC check
            rbacInstance: Optional RBAC instance for permission check

        Returns:
            Model if found and user has access (or if no user provided), None otherwise
        """
        self.refreshModels()
        model = self._models.get(displayName)

        if not model:
            return None

        # Check RBAC permission if user provided
        if currentUser and rbacInstance:
            connectorResourcePath = f"ai.model.{model.connectorType}"
            modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"

            try:
                connPerms = rbacInstance.getUserPermissions(currentUser, AccessRuleContext.RESOURCE, connectorResourcePath)
                modelPerms = rbacInstance.getUserPermissions(currentUser, AccessRuleContext.RESOURCE, modelResourcePath)
                hasConnectorAccess = connPerms.view if connPerms else False
                hasModelAccess = modelPerms.view if modelPerms else False
            except Exception as e:
                # Fail closed: an RBAC error is treated as "no access"
                logger.error(f"Error checking resource access for {modelResourcePath}: {e}")
                hasConnectorAccess = False
                hasModelAccess = False

            if not (hasConnectorAccess or hasModelAccess):
                logger.warning(f"User {currentUser.username} does not have access to model {displayName}")
                return None

        return model

    def getConnectorForModel(self, displayName: str) -> Optional[BaseConnectorAi]:
        """Get the connector instance for a specific model by displayName."""
        model = self.getModel(displayName)
        if model:
            return self._connectors.get(model.connectorType)
        return None

    def getModelStats(self) -> Dict[str, Any]:
        """Get statistics about the model registry.

        Returns:
            Dict with total/available model counts, connector count, and
            per-connector, per-capability and per-priority model counts.
        """
        self.refreshModels()

        stats = {
            "totalModels": len(self._models),
            "availableModels": len([m for m in self._models.values() if m.isAvailable]),
            "connectors": len(self._connectors),
            "byConnector": {},
            "byCapability": {},
            "byPriority": {}
        }

        byConnector = stats["byConnector"]
        byCapability = stats["byCapability"]
        byPriority = stats["byPriority"]

        for model in self._models.values():
            # Count by connector
            byConnector[model.connectorType] = byConnector.get(model.connectorType, 0) + 1
            # Count by capability
            for capability in model.capabilities:
                byCapability[capability] = byCapability.get(capability, 0) + 1
            # Count by priority
            byPriority[model.priority] = byPriority.get(model.priority, 0) + 1

        return stats


# Global registry instance
modelRegistry = ModelRegistry()


# Eager pre-warm on first import: ensures connectors are ready in this process.
# Critical for AI/agent performance — avoids 4–8 s latency on first request.
# Runs when this module is first imported (lifespan or first AI request).
def _eager_prewarm() -> None:
    """Register connectors and refresh the registry; failures are logged, not raised."""
    try:
        modelRegistry.ensureConnectorsRegistered()
        modelRegistry.refreshModels(force=True)
        logger.info("AI connectors and model registry pre-warmed (module load)")
    except Exception as e:
        logger.warning(f"AI eager pre-warm skipped: {e}")


_eager_prewarm()