gateway/modules/aicore/aicoreModelRegistry.py

307 lines
14 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Dynamic model registry that collects models from all AI connectors.
Implements plugin-like architecture for connector discovery.
"""
import logging
import importlib
import os
from typing import Dict, List, Optional, Any
from modules.datamodels.datamodelAi import AiModel
from modules.aicore.aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelUam import User
from modules.security.rbacHelpers import checkResourceAccess
from modules.security.rbac import RbacClass
from modules.connectors.connectorDbPostgre import DatabaseConnector
logger: logging.Logger = logging.getLogger(name=__name__)

# TODO TESTING: optional cap applied to every registered model's maxTokens.
# Keep at None to disable the cap; set an integer (e.g. 20000) to lower any
# model whose maxTokens exceeds it down to that value.
TESTING_MAX_TOKENS_OVERRIDE: Optional[int] = None  # TODO TESTING: Set to None to disable
class ModelRegistry:
    """Dynamic registry for AI models from all connectors.

    Connectors are discovered from ``aicorePlugin*.py`` modules located next to
    this file. Models are keyed by ``displayName``, which must be unique across
    all connectors; a collision raises ``ValueError``. Every getter refreshes
    the model table lazily, at most once per ``refreshInterval`` seconds unless
    ``refreshModels(force=True)`` is called.
    """

    def __init__(self, refreshInterval: float = 300.0):
        """Create an empty registry.

        Args:
            refreshInterval: Minimum number of seconds between automatic
                model refreshes (default 300.0, i.e. 5 minutes).
        """
        self._models: Dict[str, AiModel] = {}
        self._connectors: Dict[str, BaseConnectorAi] = {}
        # time.monotonic() timestamp of the last successful refresh (None = never).
        self._lastRefresh: Optional[float] = None
        self._refreshInterval: float = refreshInterval
        self._connectorsInitialized: bool = False

    @staticmethod
    def _applyTestingOverrides(model: AiModel) -> None:
        """Cap ``model.maxTokens`` in place when TESTING_MAX_TOKENS_OVERRIDE is set."""
        # TODO TESTING: Override maxTokens if testing override is enabled
        if TESTING_MAX_TOKENS_OVERRIDE is not None and model.maxTokens > TESTING_MAX_TOKENS_OVERRIDE:
            originalMaxTokens = model.maxTokens
            model.maxTokens = TESTING_MAX_TOKENS_OVERRIDE
            logger.debug(f"TESTING: Overrode maxTokens for {model.displayName}: {originalMaxTokens} -> {TESTING_MAX_TOKENS_OVERRIDE}")

    def _collectConnectorModels(self, connector: BaseConnectorAi, existing: Dict[str, AiModel]) -> Dict[str, AiModel]:
        """Return the connector's cached models keyed by displayName.

        ``existing`` is only read, never modified. It is consulted so that a
        displayName clash with an already-collected model is detected before
        anything is committed.

        Args:
            connector: Connector whose ``getCachedModels()`` result is read.
            existing: Models already collected from other connectors.

        Returns:
            New dict of this connector's models keyed by displayName.

        Raises:
            ValueError: If a displayName is already present in ``existing`` or
                appears twice within this connector.
        """
        connectorType = connector.getConnectorType()
        collected: Dict[str, AiModel] = {}
        for model in connector.getCachedModels():
            existingModel = collected.get(model.displayName)
            if existingModel is None:
                existingModel = existing.get(model.displayName)
            if existingModel is not None:
                errorMsg = f"Duplicate displayName '{model.displayName}' detected! Existing model: displayName='{existingModel.displayName}', name='{existingModel.name}' (connector: {existingModel.connectorType}), New model: displayName='{model.displayName}', name='{model.name}' (connector: {connectorType}). displayName must be unique."
                logger.error(errorMsg)
                raise ValueError(errorMsg)
            self._applyTestingOverrides(model)
            # Use displayName as the key (must be unique)
            collected[model.displayName] = model
        return collected

    def registerConnector(self, connector: BaseConnectorAi):
        """Register a connector and collect its models.

        A connector whose type is already registered is skipped. All of the
        connector's models are validated before anything is stored. A failed
        registration therefore leaves neither the connector nor a partial set
        of its models in the registry.

        Raises:
            ValueError: On a duplicate displayName. Any other error raised by
                the connector is logged and re-raised.
        """
        connectorType = connector.getConnectorType()
        # If connector already registered, skip re-registration to avoid duplicate models
        if connectorType in self._connectors:
            logger.debug(f"Connector {connectorType} already registered, skipping re-registration")
            return
        try:
            newModels = self._collectConnectorModels(connector, self._models)
        except Exception as e:
            logger.error(f"Failed to register models from {connectorType}: {e}")
            raise
        # Commit only after every model has been validated.
        self._connectors[connectorType] = connector
        self._models.update(newModels)
        for model in newModels.values():
            logger.debug(f"Registered model: {model.displayName} (name: {model.name}) from {connectorType}")

    def discoverConnectors(self) -> List[BaseConnectorAi]:
        """Auto-discover connectors by scanning aicorePlugin*.py files.

        Each concrete ``BaseConnectorAi`` subclass visible in a plugin module
        is instantiated at most once. A class imported by several plugins is
        not duplicated, and abstract classes are skipped. A failure to import
        a module, or to instantiate one class, is logged as a warning and
        discovery continues.

        Returns:
            The instantiated connectors, in sorted filename order.
        """
        import inspect

        connectors: List[BaseConnectorAi] = []
        seenClasses = set()
        connectorDir = os.path.dirname(__file__)
        # Sorted so discovery/registration order is deterministic across platforms.
        for filename in sorted(os.listdir(connectorDir)):
            if not (filename.startswith('aicorePlugin') and filename.endswith('.py')):
                continue
            moduleName = filename[:-3]  # Remove .py extension
            try:
                module = importlib.import_module(f'modules.aicore.{moduleName}')
            except Exception as e:
                logger.warning(f"Failed to discover connector from {filename}: {e}")
                continue
            # Find connector classes (classes that inherit from BaseConnectorAi)
            for attrName in dir(module):
                attr = getattr(module, attrName)
                if not (isinstance(attr, type) and issubclass(attr, BaseConnectorAi) and attr is not BaseConnectorAi):
                    continue
                if attr in seenClasses or inspect.isabstract(attr):
                    continue
                seenClasses.add(attr)
                try:
                    connector = attr()
                    connectors.append(connector)
                    logger.info(f"Discovered connector: {connector.getConnectorType()}")
                except Exception as e:
                    logger.warning(f"Failed to discover connector from {filename}: {e}")
        return connectors

    def ensureConnectorsRegistered(self):
        """Register connectors once to avoid per-request discovery.

        The registry is marked as initialized only after every discovered
        connector has been registered successfully. If registration raises,
        the next call retries.
        """
        if self._connectorsInitialized:
            return
        for connector in self.discoverConnectors():
            self.registerConnector(connector)
        self._connectorsInitialized = True

    def refreshModels(self, force: bool = False):
        """Refresh models from all registered connectors.

        Skipped when the last refresh happened less than ``refreshInterval``
        seconds ago, unless ``force`` is True. The new table is built into a
        fresh dict and swapped in at the end. Readers therefore never observe
        a half-built registry, and a failing connector leaves the previous
        models in place.

        Raises:
            ValueError: On a duplicate displayName. Any other connector error
                is logged and re-raised.
        """
        import time

        self.ensureConnectorsRegistered()
        # monotonic: interval checks must not be affected by wall-clock jumps.
        currentTime = time.monotonic()
        if (not force
                and self._lastRefresh is not None
                and currentTime - self._lastRefresh < self._refreshInterval):
            return
        logger.info("Refreshing model registry...")
        refreshed: Dict[str, AiModel] = {}
        for connector in self._connectors.values():
            try:
                connector.clearCache()  # Clear connector cache
                refreshed.update(self._collectConnectorModels(connector, refreshed))
            except Exception as e:
                logger.error(f"Failed to refresh models from {connector.getConnectorType()}: {e}")
                raise
        self._models = refreshed
        self._lastRefresh = currentTime
        logger.info(f"Model registry refreshed: {len(self._models)} models available")

    def getModels(self) -> List[AiModel]:
        """Get all available models."""
        self.refreshModels()
        return list(self._models.values())

    def getModelsByConnector(self, connectorType: str) -> List[AiModel]:
        """Get models from a specific connector."""
        self.refreshModels()
        return [model for model in self._models.values() if model.connectorType == connectorType]

    def getModelsByPriority(self, priority: str) -> List[AiModel]:
        """Get models that have a specific priority."""
        self.refreshModels()
        return [model for model in self._models.values() if model.priority == priority]

    def getAvailableModels(self, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> List[AiModel]:
        """Get only available models, optionally filtered by RBAC permissions.

        Args:
            currentUser: Optional user object for RBAC filtering
            rbacInstance: Optional RBAC instance for permission checks

        Returns:
            List of models with ``isAvailable`` set. RBAC filtering is applied
            only when both ``currentUser`` and ``rbacInstance`` are provided.
        """
        self.refreshModels()
        allModels = list(self._models.values())
        availableModels = [model for model in allModels if model.isAvailable]
        # Count unavailability before RBAC filtering so RBAC-denied models are
        # not misreported as "unavailable".
        unavailableModels = [m.name for m in allModels if not m.isAvailable]
        if unavailableModels:
            logger.debug(f"getAvailableModels: {len(availableModels)} available, {len(unavailableModels)} unavailable. Unavailable: {unavailableModels}")
        # Apply RBAC filtering if user and RBAC instance provided
        if currentUser and rbacInstance:
            availableModels = self._filterModelsByRbac(availableModels, currentUser, rbacInstance)
        logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
        return availableModels

    @staticmethod
    def _hasModelAccess(model: AiModel, currentUser: User, rbacInstance: RbacClass) -> bool:
        """Return True if the user may access the model's connector (all its models) or this model."""
        connectorResourcePath = f"ai.model.{model.connectorType}"
        if checkResourceAccess(rbacInstance, currentUser, connectorResourcePath):
            return True
        modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"
        return bool(checkResourceAccess(rbacInstance, currentUser, modelResourcePath))

    def _filterModelsByRbac(self, models: List[AiModel], currentUser: User, rbacInstance: RbacClass) -> List[AiModel]:
        """Filter models based on RBAC permissions.

        Args:
            models: List of models to filter
            currentUser: Current user object
            rbacInstance: RBAC instance for permission checks

        Returns:
            Filtered list of models that user has access to, in input order
        """
        filteredModels = []
        for model in models:
            if self._hasModelAccess(model, currentUser, rbacInstance):
                filteredModels.append(model)
            else:
                logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")
        return filteredModels

    def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]:
        """Get a specific model by displayName, optionally checking RBAC permissions.

        Args:
            displayName: Model display name
            currentUser: Optional user object for RBAC check
            rbacInstance: Optional RBAC instance for permission check

        Returns:
            Model if found and user has access (or if no user provided), None otherwise
        """
        self.refreshModels()
        model = self._models.get(displayName)
        if not model:
            return None
        # Check RBAC permission if user provided
        if currentUser and rbacInstance and not self._hasModelAccess(model, currentUser, rbacInstance):
            logger.warning(f"User {currentUser.username} does not have access to model {displayName}")
            return None
        return model

    def getConnectorForModel(self, displayName: str) -> Optional[BaseConnectorAi]:
        """Get the connector instance for a specific model by displayName (no RBAC check)."""
        model = self.getModel(displayName)
        if model:
            return self._connectors.get(model.connectorType)
        return None

    def getModelStats(self) -> Dict[str, Any]:
        """Get statistics about the model registry.

        Returns:
            Dict with total/available model counts, the number of connectors,
            and per-connector, per-capability and per-priority model counts.
        """
        from collections import Counter

        self.refreshModels()
        models = list(self._models.values())
        return {
            "totalModels": len(models),
            "availableModels": sum(1 for m in models if m.isAvailable),
            "connectors": len(self._connectors),
            "byConnector": dict(Counter(m.connectorType for m in models)),
            "byCapability": dict(Counter(capability for m in models for capability in m.capabilities)),
            "byPriority": dict(Counter(m.priority for m in models)),
        }
# Module-level shared ModelRegistry instance, created once at import time.
modelRegistry: ModelRegistry = ModelRegistry()