# gateway/modules/aicore/aicoreModelSelector.py

"""
Simplified model selection based on model properties and priority-based sorting.
No complex rules needed - just filter by properties and sort by priority!
"""
import logging
from typing import List, Dict, Any, Optional
from modules.datamodels.datamodelAi import AiModel, AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum
# Configure logger
logger = logging.getLogger(__name__)


class ModelSelector:
    """Simple model selector based on model properties and score-based ranking."""

    def __init__(self):
        logger.info("ModelSelector initialized with simplified approach")

    def selectModel(self,
                    prompt: str,
                    context: str,
                    options: AiCallOptions,
                    availableModels: List[AiModel]) -> Optional[AiModel]:
        """
        Select the best model using simple filtering and score-based ranking.

        Args:
            prompt: User prompt
            context: Context data
            options: AI call options
            availableModels: List of available models

        Returns:
            The best model for the request, or None if no suitable model was found
        """
        try:
            # Get the failover model list (which includes all filtering and sorting)
            failoverModelList = self.getFailoverModelList(prompt, context, options, availableModels)
            if not failoverModelList:
                logger.warning("No suitable models found for the request")
                return None
            selectedModel = failoverModelList[0]  # First model is the best one
            logger.info(f"Selected model: {selectedModel.name} (quality: {selectedModel.qualityRating}, cost: ${selectedModel.costPer1kTokensInput:.4f})")
            return selectedModel
        except Exception as e:
            logger.error(f"Error selecting model: {str(e)}")
            return None

    def getFailoverModelList(self,
                             prompt: str,
                             context: str,
                             options: AiCallOptions,
                             availableModels: List[AiModel]) -> List[AiModel]:
        """
        Get a prioritized list of models using scoring-based ranking.

        Args:
            prompt: User prompt
            context: Context data
            options: AI call options
            availableModels: List of available models

        Returns:
            List of models sorted by score (descending)
        """
        try:
            promptSize = len(prompt.encode("utf-8"))
            contextSize = len(context.encode("utf-8"))
            totalSize = promptSize + contextSize

            # Step 1: Filter by operation type (MUST match)
            operationFiltered = [m for m in availableModels if options.operationType in m.operationTypes]
            logger.debug(f"After operation type filtering: {len(operationFiltered)} models")

            # Step 2: Filter by prompt size (the prompt MUST fit into 80% of the model's context length,
            # unless the model has no context length limit)
            promptFiltered = [m for m in operationFiltered if m.contextLength == 0 or promptSize <= m.contextLength * 0.8]
            logger.debug(f"After prompt size filtering: {len(promptFiltered)} models")

            # Step 3: Calculate a score for each remaining model
            scoredModels = []
            for model in promptFiltered:
                score = self._calculateModelScore(model, promptSize, contextSize, totalSize, options)
                scoredModels.append((model, score))
                logger.debug(f"Model {model.name}: score={score:.3f}")

            # Step 4: Sort by score (descending)
            scoredModels.sort(key=lambda x: x[1], reverse=True)
            sortedModels = [model for model, score in scoredModels]
            logger.debug(f"Final sorted models: {len(sortedModels)} models")
            return sortedModels
        except Exception as e:
            logger.error(f"Error getting failover models: {str(e)}")
            return []

    def _calculateModelScore(self, model: AiModel, promptSize: int, contextSize: int, totalSize: int, options: AiCallOptions) -> float:
        """
        Calculate a score for a model based on how well it fulfills the criteria.

        Args:
            model: The model to score
            promptSize: Size of the prompt in bytes
            contextSize: Size of the context in bytes
            totalSize: Total size (prompt + context) in bytes
            options: AI call options

        Returns:
            Score for the model (higher is better)
        """
        score = 0.0

        # 1. Prompt + context size rating (see _getSizeRating for the formula)
        score += self._getSizeRating(model, totalSize)

        # 2. Processing mode rating
        if hasattr(options, 'processingMode') and options.processingMode:
            score += self._getProcessingModeRating(model.processingMode, options.processingMode)
        else:
            score += 1.0  # No preference

        # 3. Priority rating
        if hasattr(options, 'priority') and options.priority:
            score += self._getPriorityRating(model, options.priority)
        else:
            score += 1.0  # No preference

        return score
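
    # Illustrative scoring example (hypothetical numbers, not taken from any real
    # model configuration): for a model with contextLength=10000 and a request with
    # totalSize=6000 bytes, the size rating is 6000 / (10000 * 0.8) = 0.75; an exact
    # processing-mode match adds 1.0, and PriorityEnum.QUALITY with qualityRating=8
    # adds 8 / 10.0 = 0.8, giving a total score of 2.55.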

    def _getProcessingModeRating(self, modelMode: ProcessingModeEnum, requestedMode: ProcessingModeEnum) -> float:
        """Get processing mode rating based on compatibility."""
        if modelMode == requestedMode:
            return 1.0
        # Compatibility matrix
        if requestedMode == ProcessingModeEnum.BASIC:
            if modelMode == ProcessingModeEnum.ADVANCED:
                return 0.5
            elif modelMode == ProcessingModeEnum.DETAILED:
                return 0.2
        elif requestedMode == ProcessingModeEnum.ADVANCED:
            if modelMode == ProcessingModeEnum.BASIC:
                return 0.2
            elif modelMode == ProcessingModeEnum.DETAILED:
                return 0.5
        elif requestedMode == ProcessingModeEnum.DETAILED:
            if modelMode == ProcessingModeEnum.BASIC:
                return 0.2
            elif modelMode == ProcessingModeEnum.ADVANCED:
                return 0.5
        return 0.0  # No compatibility
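
    # Summary of the compatibility matrix above (requested mode vs. model mode):
    #   BASIC    requested: ADVANCED scores 0.5, DETAILED scores 0.2
    #   ADVANCED requested: DETAILED scores 0.5, BASIC scores 0.2
    #   DETAILED requested: ADVANCED scores 0.5, BASIC scores 0.2
    # An exact match always scores 1.0; any other combination scores 0.0.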

    def _getPriorityRating(self, model: AiModel, requestedPriority: PriorityEnum) -> float:
        """Get priority rating based on model capabilities."""
        if requestedPriority == PriorityEnum.BALANCED:
            return 1.0
        elif requestedPriority == PriorityEnum.SPEED:
            return model.speedRating / 10.0
        elif requestedPriority == PriorityEnum.QUALITY:
            return model.qualityRating / 10.0
        elif requestedPriority == PriorityEnum.COST:
            # Cost priority: cost is weighted 1.0, speed 0.5, quality 0.2.
            # Lower cost is better, so the cost rating is inverted.
            costRating = 1.0 - (model.costPer1kTokensInput / 0.1)  # Normalize against a $0.10/1k ceiling
            costRating = max(0, costRating)  # Ensure non-negative
            speedRating = model.speedRating / 10.0 * 0.5
            qualityRating = model.qualityRating / 10.0 * 0.2
            return costRating + speedRating + qualityRating
        return 1.0  # Default
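
    # Worked example for PriorityEnum.COST (hypothetical ratings, for illustration
    # only): a model with costPer1kTokensInput=0.03, speedRating=8 and qualityRating=7
    # scores (1.0 - 0.03 / 0.1) + (8 / 10.0 * 0.5) + (7 / 10.0 * 0.2)
    # = 0.7 + 0.4 + 0.14 = 1.24.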

    def _getSizeRating(self, model: AiModel, totalSize: int) -> float:
        """Get size rating for a model based on total input size."""
        if model.contextLength > 0:
            modelMaxSize = model.contextLength * 0.8  # 80% of the model's context length
            if totalSize <= modelMaxSize:
                # Within limits: rating = (prompt + context size) / (80% of model size)
                return totalSize / modelMaxSize
            else:
                # Exceeds limits: rating = model size / (prompt + context size) (ensures minimum chunks)
                return modelMaxSize / totalSize
        else:
            # No context length limit
            return 1.0
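
    # Example for the oversized branch (hypothetical values): with contextLength=4000
    # the usable budget is 4000 * 0.8 = 3200; a request with totalSize=4800 exceeds it
    # and rates 3200 / 4800 ≈ 0.67, so oversized requests are ranked lower by
    # _calculateModelScore rather than excluded outright.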

    def _logModelDetails(self, model: AiModel):
        """Log detailed information about a model."""
        logger.info(f"Model: {model.name}")
        logger.info(f" Display Name: {model.displayName}")
        logger.info(f" Connector: {model.connectorType}")
        logger.info(f" Context Length: {model.contextLength}")
        logger.info(f" Max Tokens: {model.maxTokens}")
        logger.info(f" Quality Rating: {model.qualityRating}/10")
        logger.info(f" Speed Rating: {model.speedRating}/10")
        logger.info(f" Cost: ${model.costPer1kTokensInput:.4f}/1k tokens")
        logger.info(f" Capabilities: {', '.join(model.capabilities)}")
        logger.info(f" Priority: {model.priority}")
        logger.info(f" Processing Mode: {model.processingMode}")
        logger.info(f" Operation Types: {', '.join(model.operationTypes)}")


# Global model selector instance
modelSelector = ModelSelector()
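
# Example usage (illustrative sketch only; assumes AiCallOptions carries the
# operationType, priority and processingMode fields referenced above, and that an
# OperationTypeEnum member such as SUMMARIZE exists - the actual enum members and
# constructor signatures live in modules.datamodels.datamodelAi):
#
#     options = AiCallOptions(
#         operationType=OperationTypeEnum.SUMMARIZE,   # hypothetical member
#         priority=PriorityEnum.COST,
#         processingMode=ProcessingModeEnum.BASIC,
#     )
#     best = modelSelector.selectModel(prompt, context, options, availableModels)
#     if best is None:
#         # no model matched the operation type and prompt size constraints
#         ...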