gateway/modules/aicore/aicoreModelSelector.py
2026-01-23 01:10:00 +01:00

281 lines
No EOL
12 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Simplified model selection based on model properties and priority-based sorting.
No complex rules needed - just filter by properties and sort by priority!
"""
import logging
from typing import List, Dict, Any, Optional
from modules.datamodels.datamodelAi import AiModel, AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum
# Configure logger
logger = logging.getLogger(__name__)
class ModelSelector:
    """Select AI models by filtering on their properties and ranking by score.

    Stateless: all inputs arrive per call, so a single shared instance is safe.
    """

    def __init__(self) -> None:
        # No configuration to load — the log line only records construction.
        logger.info("ModelSelector initialized with simplified approach")
def selectModel(self,
prompt: str,
context: str,
options: AiCallOptions,
availableModels: List[AiModel]) -> Optional[AiModel]:
"""
Select the best model using simple filtering and priority-based sorting.
Args:
prompt: User prompt
context: Context data
options: AI call options
availableModels: List of available models
Returns:
Best model for the request, or None if no suitable model found
"""
try:
# Get failover models (which includes all filtering and sorting)
failoverModelList = self.getFailoverModelList(prompt, context, options, availableModels)
if not failoverModelList:
logger.warning("No suitable models found for the request")
return None
selectedModel = failoverModelList[0] # First model is the best one
logger.info(f"Selected model: {selectedModel.name} (quality: {selectedModel.qualityRating}, cost: ${selectedModel.costPer1kTokensInput:.4f})")
return selectedModel
except Exception as e:
logger.error(f"Error selecting model: {str(e)}")
return None
def getFailoverModelList(self,
prompt: str,
context: str,
options: AiCallOptions,
availableModels: List[AiModel]) -> List[AiModel]:
"""
Get prioritized list of models using scoring-based ranking.
Args:
prompt: User prompt
context: Context data
options: AI call options
availableModels: List of available models
Returns:
List of models sorted by score (descending)
"""
try:
promptSize = len(prompt.encode("utf-8"))
contextSize = len(context.encode("utf-8"))
totalSize = promptSize + contextSize
# Convert bytes to approximate tokens (1 token ≈ 4 bytes)
promptTokens = promptSize / 4
contextTokens = contextSize / 4
totalTokens = totalSize / 4
logger.debug(f"Request sizes - Prompt: {promptTokens:.0f} tokens ({promptSize} bytes), Context: {contextTokens:.0f} tokens ({contextSize} bytes), Total: {totalTokens:.0f} tokens ({totalSize} bytes)")
# Step 1: Filter by operation type (MUST match) - check if model has this operation type
operationFiltered = []
for model in availableModels:
# Check if model has the required operation type
hasOperationType = any(ot.operationType == options.operationType for ot in model.operationTypes)
if hasOperationType:
operationFiltered.append(model)
logger.debug(f"After operation type filtering: {len(operationFiltered)} models")
if operationFiltered:
logger.debug(f"Models with {options.operationType.value}: {[m.name for m in operationFiltered]}")
# Step 2: Filter by prompt size (MUST be <= 80% of context size)
# Note: contextLength is in tokens, so we need to compare tokens with tokens
promptFiltered = []
for model in operationFiltered:
if model.contextLength == 0:
# No context length limit - always pass
promptFiltered.append(model)
else:
maxAllowedTokens = model.contextLength * 0.8
# Compare prompt tokens (not bytes) with model's token limit
if promptTokens <= maxAllowedTokens:
promptFiltered.append(model)
else:
logger.debug(f"Model {model.name} filtered out: promptSize={promptTokens:.0f} tokens > maxAllowed={maxAllowedTokens:.0f} tokens (80% of {model.contextLength} tokens)")
logger.debug(f"After prompt size filtering: {len(promptFiltered)} models")
if not promptFiltered and operationFiltered:
logger.warning(f"All {len(operationFiltered)} models with {options.operationType.value} were filtered out due to prompt size. Prompt: {promptTokens:.0f} tokens. Available models:")
for model in operationFiltered:
maxAllowed = model.contextLength * 0.8 / 4 if model.contextLength > 0 else "unlimited"
logger.warning(f" - {model.name}: contextLength={model.contextLength} tokens, maxAllowed={maxAllowed} tokens")
# Step 3: Calculate scores for each model
scoredModels = []
for model in promptFiltered:
score = self._calculateModelScore(model, promptSize, contextSize, totalSize, options)
scoredModels.append((model, score))
logger.debug(f"Model {model.name}: score={score:.3f}")
# Step 4: Sort by score (descending)
scoredModels.sort(key=lambda x: x[1], reverse=True)
sortedModels = [model for model, score in scoredModels]
logger.debug(f"Final sorted models: {len(sortedModels)} models")
return sortedModels
except Exception as e:
logger.error(f"Error getting failover models: {str(e)}")
return []
def _calculateModelScore(self, model: AiModel, promptSize: int, contextSize: int, totalSize: int, options: AiCallOptions) -> float:
"""
Calculate a score for a model based on how well it fulfills the criteria.
Operation type rating is the PRIMARY sorting criteria (multiplied by 1000).
Args:
model: The model to score
promptSize: Size of the prompt in bytes
contextSize: Size of the context in bytes
totalSize: Total size (prompt + context) in bytes
options: AI call options
Returns:
Score for the model (higher is better)
"""
score = 0.0
# 1. PRIMARY: Operation Type Rating (multiplied by 1000 for primary sorting)
operationTypeRating = self._getOperationTypeRating(model, options.operationType)
score += operationTypeRating * 1000.0 # Primary sorting criteria
# 2. Prompt + Context size rating
if model.contextLength > 0:
modelMaxSize = model.contextLength * 0.8 # 80% of model context length
if totalSize <= modelMaxSize:
# Within limits: rating = (prompt+contextsize) / (80% modelsize)
score += totalSize / modelMaxSize
else:
# Exceeds limits: rating = modelsize / (prompt+contextsize) (ensures minimum chunks)
score += modelMaxSize / totalSize
else:
# No context length limit
score += 1.0
# 3. Processing Mode rating
if hasattr(options, 'processingMode') and options.processingMode:
score += self._getProcessingModeRating(model.processingMode, options.processingMode)
else:
score += 1.0 # No preference
# 4. Priority rating
if hasattr(options, 'priority') and options.priority:
score += self._getPriorityRating(model, options.priority)
else:
score += 1.0 # No preference
return score
def _getOperationTypeRating(self, model: AiModel, operationType: OperationTypeEnum) -> float:
"""
Get the operation type rating for a model.
Args:
model: The model to check
operationType: The operation type to get rating for
Returns:
Rating (1-10) or 0 if model doesn't support this operation type
"""
for ot_rating in model.operationTypes:
if ot_rating.operationType == operationType:
return float(ot_rating.rating)
return 0.0 # Model doesn't support this operation type
def _getProcessingModeRating(self, modelMode: ProcessingModeEnum, requestedMode: ProcessingModeEnum) -> float:
"""Get processing mode rating based on compatibility."""
if modelMode == requestedMode:
return 1.0
# Compatibility matrix
if requestedMode == ProcessingModeEnum.BASIC:
if modelMode == ProcessingModeEnum.ADVANCED:
return 0.5
elif modelMode == ProcessingModeEnum.DETAILED:
return 0.2
elif requestedMode == ProcessingModeEnum.ADVANCED:
if modelMode == ProcessingModeEnum.BASIC:
return 0.2
elif modelMode == ProcessingModeEnum.DETAILED:
return 0.5
elif requestedMode == ProcessingModeEnum.DETAILED:
if modelMode == ProcessingModeEnum.BASIC:
return 0.2
elif modelMode == ProcessingModeEnum.ADVANCED:
return 0.5
return 0.0 # No compatibility
def _getPriorityRating(self, model: AiModel, requestedPriority: PriorityEnum) -> float:
"""Get priority rating based on model capabilities."""
if requestedPriority == PriorityEnum.BALANCED:
return 1.0
elif requestedPriority == PriorityEnum.SPEED:
return model.speedRating / 10.0
elif requestedPriority == PriorityEnum.QUALITY:
return model.qualityRating / 10.0
elif requestedPriority == PriorityEnum.COST:
# Cost priority: cost gives 1, speed gives 0.5, quality gives 0.2
# Lower cost is better, so we invert the cost rating
costRating = 1.0 - (model.costPer1kTokensInput / 0.1) # Normalize to 0-1
costRating = max(0, costRating) # Ensure non-negative
speedRating = model.speedRating / 10.0 * 0.5
qualityRating = model.qualityRating / 10.0 * 0.2
return costRating + speedRating + qualityRating
return 1.0 # Default
def _getSizeRating(self, model: AiModel, totalSize: int) -> float:
"""Get size rating for a model based on total input size."""
if model.contextLength > 0:
modelMaxSize = model.contextLength * 0.8 # 80% of model context length
if totalSize <= modelMaxSize:
# Within limits: rating = (prompt+contextsize) / (80% modelsize)
return totalSize / modelMaxSize
else:
# Exceeds limits: rating = modelsize / (prompt+contextsize) (ensures minimum chunks)
return modelMaxSize / totalSize
else:
# No context length limit
return 1.0
def _logModelDetails(self, model: AiModel):
"""Log detailed information about a model."""
logger.info(f"Model: {model.name}")
logger.info(f" Display Name: {model.displayName}")
logger.info(f" Connector: {model.connectorType}")
logger.info(f" Context Length: {model.contextLength}")
logger.info(f" Max Tokens: {model.maxTokens}")
logger.info(f" Quality Rating: {model.qualityRating}/10")
logger.info(f" Speed Rating: {model.speedRating}/10")
logger.info(f" Cost: ${model.costPer1kTokensInput:.4f}/1k tokens")
logger.info(f" Priority: {model.priority}")
logger.info(f" Processing Mode: {model.processingMode}")
operationTypesStr = ', '.join([f"{ot.operationType.value}({ot.rating})" for ot in model.operationTypes])
logger.info(f" Operation Types: {operationTypesStr}")
# Global model selector instance
# Module-level singleton: importers share this one stateless instance.
modelSelector = ModelSelector()