295 lines
14 KiB
Python
295 lines
14 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Dynamic model registry that collects models from all AI connectors.
|
|
Implements plugin-like architecture for connector discovery.
|
|
"""
|
|
|
|
import logging
|
|
import importlib
|
|
import os
|
|
from typing import Dict, List, Optional, Any
|
|
from modules.datamodels.datamodelAi import AiModel
|
|
from modules.aicore.aicoreBase import BaseConnectorAi
|
|
from modules.datamodels.datamodelUam import User
|
|
from modules.security.rbacHelpers import checkResourceAccess
|
|
from modules.security.rbac import RbacClass
|
|
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
# TODO TESTING: Override maxTokens for all models during testing
|
|
# Set to None to disable override, or set to an integer (e.g., 20000) to override all models
|
|
# When not None, registration and refresh lower the maxTokens of any model whose
# maxTokens exceeds this value; models at or below it are left untouched.
TESTING_MAX_TOKENS_OVERRIDE: Optional[int] = None # TODO TESTING: Set to None to disable
|
|
|
|
class ModelRegistry:
    """Dynamic registry for AI models from all connectors.

    Models are stored by ``displayName``, which must be unique across every
    registered connector. The read accessors call ``refreshModels()`` first,
    which rebuilds the model table from the registered connectors once the
    refresh interval has elapsed.
    """

    def __init__(self):
        # displayName -> model
        self._models: Dict[str, AiModel] = {}
        # connector type -> connector instance
        self._connectors: Dict[str, BaseConnectorAi] = {}
        # time.time() of the last successful refresh; None until the first one
        self._lastRefresh: Optional[float] = None
        self._refreshInterval: float = 300.0  # 5 minutes

    def _collectModels(self, models: List[AiModel], connectorType: str, target: Dict[str, AiModel]) -> None:
        """Validate ``models`` and add them to ``target`` keyed by displayName.

        Shared by ``registerConnector`` and ``refreshModels`` so both apply the
        same uniqueness check and testing override.

        Args:
            models: Models reported by one connector.
            connectorType: Type of the connector the models come from (used in
                the error message only).
            target: Mapping to fill; also the mapping duplicates are checked against.

        Raises:
            ValueError: If a model's displayName is already present in ``target``.
        """
        for model in models:
            # Validate displayName uniqueness
            if model.displayName in target:
                existingModel = target[model.displayName]
                errorMsg = (
                    f"Duplicate displayName '{model.displayName}' detected! "
                    f"Existing model: displayName='{existingModel.displayName}', "
                    f"name='{existingModel.name}' (connector: {existingModel.connectorType}), "
                    f"New model: displayName='{model.displayName}', "
                    f"name='{model.name}' (connector: {connectorType}). "
                    f"displayName must be unique."
                )
                logger.error(errorMsg)
                raise ValueError(errorMsg)

            # TODO TESTING: Override maxTokens if testing override is enabled
            if TESTING_MAX_TOKENS_OVERRIDE is not None and model.maxTokens > TESTING_MAX_TOKENS_OVERRIDE:
                originalMaxTokens = model.maxTokens
                model.maxTokens = TESTING_MAX_TOKENS_OVERRIDE
                logger.debug(f"TESTING: Overrode maxTokens for {model.displayName}: {originalMaxTokens} -> {TESTING_MAX_TOKENS_OVERRIDE}")

            # Use displayName as the key (must be unique)
            target[model.displayName] = model

    def _hasModelAccess(self, model: AiModel, currentUser: User, rbacInstance: RbacClass) -> bool:
        """Return True if the user may use ``model``.

        Access is granted through either the connector-level resource
        (``ai.model.<connectorType>``) or the model-level resource
        (``ai.model.<connectorType>.<displayName>``).
        """
        connectorResourcePath = f"ai.model.{model.connectorType}"
        modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"

        # Both checks are always evaluated (no short-circuit), as before.
        hasConnectorAccess = checkResourceAccess(rbacInstance, currentUser, connectorResourcePath)
        hasModelAccess = checkResourceAccess(rbacInstance, currentUser, modelResourcePath)
        return bool(hasConnectorAccess or hasModelAccess)

    def registerConnector(self, connector: BaseConnectorAi):
        """Register a connector and collect its models.

        Registration is all-or-nothing: the connector and its models are only
        committed once every model has passed validation, so a failure leaves
        the registry exactly as it was.

        Raises:
            ValueError: If one of the connector's models has a displayName that
                is already registered (or repeated within the connector).
            Exception: Anything raised while reading the connector's models is
                logged and re-raised.
        """
        connectorType = connector.getConnectorType()

        # If connector already registered, skip re-registration to avoid duplicate models
        if connectorType in self._connectors:
            logger.debug(f"Connector {connectorType} already registered, skipping re-registration")
            return

        try:
            models = list(connector.getCachedModels())
            # Validate against a copy so a failure cannot leave partial state behind.
            staged = dict(self._models)
            self._collectModels(models, connectorType, staged)
        except Exception as e:
            logger.error(f"Failed to register models from {connectorType}: {e}")
            raise

        # Commit only after validation succeeded.
        self._connectors[connectorType] = connector
        self._models.update(staged)
        for model in models:
            logger.debug(f"Registered model: {model.displayName} (name: {model.name}) from {connectorType}")

    def discoverConnectors(self) -> List[BaseConnectorAi]:
        """Auto-discover connectors by scanning aicorePlugin*.py files.

        Every ``aicorePlugin*.py`` file next to this module is imported as
        ``modules.aicore.<name>`` and each ``BaseConnectorAi`` subclass found in
        it is instantiated with no arguments. A failure while importing or
        instantiating is logged as a warning and the rest of that file is skipped.

        Returns:
            The connector instances created. They are not registered here.
        """
        connectors = []
        connectorDir = os.path.dirname(__file__)

        # Scan for connector files
        for filename in os.listdir(connectorDir):
            if not (filename.startswith('aicorePlugin') and filename.endswith('.py')):
                continue

            moduleName = filename[:-3]  # Remove .py extension

            try:
                module = importlib.import_module(f'modules.aicore.{moduleName}')

                # Find connector classes (classes that inherit from BaseConnectorAi)
                for attrName in dir(module):
                    attr = getattr(module, attrName)
                    if (isinstance(attr, type)
                            and issubclass(attr, BaseConnectorAi)
                            and attr != BaseConnectorAi):
                        connector = attr()
                        connectors.append(connector)
                        logger.info(f"Discovered connector: {connector.getConnectorType()}")
            except Exception as e:
                logger.warning(f"Failed to discover connector from {filename}: {e}")

        return connectors

    def refreshModels(self, force: bool = False):
        """Refresh models from all registered connectors.

        Does nothing when a refresh happened less than ``_refreshInterval``
        seconds ago, unless ``force`` is True. The new model table is built
        separately and swapped in only when every connector succeeded, so a
        failing connector does not wipe the models that were available before.

        Args:
            force: Refresh even if the refresh interval has not elapsed.

        Raises:
            ValueError: If two models share a displayName.
            Exception: Anything raised by a connector is logged and re-raised.
        """
        import time

        currentTime = time.time()

        # Check if refresh is needed
        if (not force
                and self._lastRefresh is not None
                and currentTime - self._lastRefresh < self._refreshInterval):
            return

        logger.info("Refreshing model registry...")

        refreshed: Dict[str, AiModel] = {}
        for connector in self._connectors.values():
            try:
                connector.clearCache()  # Clear connector cache
                self._collectModels(connector.getCachedModels(), connector.getConnectorType(), refreshed)
            except Exception as e:
                logger.error(f"Failed to refresh models from {connector.getConnectorType()}: {e}")
                raise

        # Swap in the rebuilt table only after all connectors were processed.
        self._models = refreshed
        self._lastRefresh = currentTime
        logger.info(f"Model registry refreshed: {len(self._models)} models available")

    def getModels(self) -> List[AiModel]:
        """Get all available models."""
        self.refreshModels()
        return list(self._models.values())

    def getModelsByConnector(self, connectorType: str) -> List[AiModel]:
        """Get models from a specific connector."""
        self.refreshModels()
        return [model for model in self._models.values() if model.connectorType == connectorType]

    def getModelsByPriority(self, priority: str) -> List[AiModel]:
        """Get models that have a specific priority."""
        self.refreshModels()
        return [model for model in self._models.values() if model.priority == priority]

    def getAvailableModels(self, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> List[AiModel]:
        """Get only available models, optionally filtered by RBAC permissions.

        Args:
            currentUser: Optional user object for RBAC filtering
            rbacInstance: Optional RBAC instance for permission checks

        Returns:
            List of available models (filtered by RBAC if user provided)
        """
        self.refreshModels()
        allModels = list(self._models.values())
        availableModels = [model for model in allModels if model.isAvailable]

        # Take the availability numbers before RBAC filtering so models hidden
        # by permissions are not reported as "unavailable".
        availableCount = len(availableModels)
        unavailableModels = [m.name for m in allModels if not m.isAvailable]

        # Apply RBAC filtering if user and RBAC instance provided
        if currentUser and rbacInstance:
            availableModels = self._filterModelsByRbac(availableModels, currentUser, rbacInstance)

        if unavailableModels:
            logger.debug(f"getAvailableModels: {availableCount} available, {len(unavailableModels)} unavailable. Unavailable: {unavailableModels}")
        logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
        return availableModels

    def _filterModelsByRbac(self, models: List[AiModel], currentUser: User, rbacInstance: RbacClass) -> List[AiModel]:
        """Filter models based on RBAC permissions.

        Args:
            models: List of models to filter
            currentUser: Current user object
            rbacInstance: RBAC instance for permission checks

        Returns:
            Filtered list of models that user has access to
        """
        filteredModels = []
        for model in models:
            # User needs access to either connector (all models) or specific model
            if self._hasModelAccess(model, currentUser, rbacInstance):
                filteredModels.append(model)
            else:
                logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")

        return filteredModels

    def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]:
        """Get a specific model by displayName, optionally checking RBAC permissions.

        Args:
            displayName: Model display name
            currentUser: Optional user object for RBAC check
            rbacInstance: Optional RBAC instance for permission check

        Returns:
            Model if found and user has access (or if no user provided), None otherwise
        """
        self.refreshModels()
        model = self._models.get(displayName)

        if not model:
            return None

        # Check RBAC permission if user provided
        if currentUser and rbacInstance:
            if not self._hasModelAccess(model, currentUser, rbacInstance):
                logger.warning(f"User {currentUser.username} does not have access to model {displayName}")
                return None

        return model

    def getConnectorForModel(self, displayName: str) -> Optional[BaseConnectorAi]:
        """Get the connector instance for a specific model by displayName."""
        model = self.getModel(displayName)
        if model:
            return self._connectors.get(model.connectorType)
        return None

    def getModelStats(self) -> Dict[str, Any]:
        """Get statistics about the model registry.

        Returns:
            Dict with ``totalModels``, ``availableModels``, ``connectors`` and
            the per-key counts ``byConnector``, ``byCapability``, ``byPriority``.
        """
        self.refreshModels()

        byConnector: Dict[Any, int] = {}
        byCapability: Dict[Any, int] = {}
        byPriority: Dict[Any, int] = {}
        availableCount = 0

        # Single pass over the models; each counter keeps first-seen key order.
        for model in self._models.values():
            if model.isAvailable:
                availableCount += 1
            byConnector[model.connectorType] = byConnector.get(model.connectorType, 0) + 1
            for capability in model.capabilities:
                byCapability[capability] = byCapability.get(capability, 0) + 1
            byPriority[model.priority] = byPriority.get(model.priority, 0) + 1

        return {
            "totalModels": len(self._models),
            "availableModels": availableCount,
            "connectors": len(self._connectors),
            "byConnector": byConnector,
            "byCapability": byCapability,
            "byPriority": byPriority,
        }
|
|
|
|
|
|
# Global registry instance
|
|
# Created once when this module is first imported; starts with no connectors registered.
modelRegistry = ModelRegistry()
|