# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Simplified model selection based on model properties and priority-based sorting.
No complex rules needed - just filter by properties and sort by priority!
"""

import logging
import time
from typing import List, Dict, Any, Optional, Tuple
from modules.datamodels.datamodelAi import AiModel, AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum

logger = logging.getLogger(__name__)
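
# Deprioritization window: how long (in seconds) a recently failed model is penalized in scoring.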
_COOLDOWN_DURATION = 60.0


class ModelSelector:
    """Model selector with priority scoring and recent-failure cooldown."""

    def __init__(self):
        self._failureLog: Dict[str, float] = {}
        logger.info("ModelSelector initialized with failure cooldown support")

    def reportFailure(self, modelName: str):
        """Record that a model just failed (rate limit, error, etc.).
        The model will be deprioritized for _COOLDOWN_DURATION seconds."""
        self._failureLog[modelName] = time.time()
        logger.info(f"ModelSelector: Recorded failure for {modelName}, cooldown {_COOLDOWN_DURATION}s")

    def _getCooldownPenalty(self, modelName: str) -> float:
        """Return a score penalty (0.0 = no penalty, large negative = recently failed)."""
        failedAt = self._failureLog.get(modelName)
        if failedAt is None:
            return 0.0
        elapsed = time.time() - failedAt
        if elapsed > _COOLDOWN_DURATION:
            del self._failureLog[modelName]
            return 0.0
        remaining = _COOLDOWN_DURATION - elapsed
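        # Illustrative arithmetic (assumed timing): 30 s after a failure, remaining = 30,
        # so the penalty is -(30 / 60) * 5000 = -2500, i.e. roughly 2.5 points of
        # operation-type rating in _calculateModelScore (which weights that rating by 1000).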
        return -(remaining / _COOLDOWN_DURATION) * 5000.0

    def selectModel(self,
                    prompt: str,
                    context: str,
                    options: AiCallOptions,
                    availableModels: List[AiModel]) -> Optional[AiModel]:
        """
        Select the best model using simple filtering and priority-based sorting.

        Args:
            prompt: User prompt
            context: Context data
            options: AI call options
            availableModels: List of available models

        Returns:
            Best model for the request, or None if no suitable model is found
        """
        try:
            # Get failover models (which includes all filtering and sorting)
            failoverModelList = self.getFailoverModelList(prompt, context, options, availableModels)

            if not failoverModelList:
                logger.warning("No suitable models found for the request")
                return None

            selectedModel = failoverModelList[0]  # First model is the best one
            logger.info(f"Selected model: {selectedModel.name} (quality: {selectedModel.qualityRating}, cost: ${selectedModel.costPer1kTokensInput:.4f})")
            return selectedModel

        except Exception as e:
            logger.error(f"Error selecting model: {str(e)}")
            return None

    def getFailoverModelList(self,
                             prompt: str,
                             context: str,
                             options: AiCallOptions,
                             availableModels: List[AiModel]) -> List[AiModel]:
        """
        Get a prioritized list of models using scoring-based ranking.

        Args:
            prompt: User prompt
            context: Context data
            options: AI call options
            availableModels: List of available models

        Returns:
            List of models sorted by score (descending)
        """
        try:
            promptSize = len(prompt.encode("utf-8"))
            contextSize = len(context.encode("utf-8"))
            totalSize = promptSize + contextSize
            # Convert bytes to approximate tokens.
            # Balanced estimate: 1 token ≈ 3 bytes.
            # Note: actual tokenization varies by content type and model:
            # - English text: ~4 bytes/token
            # - German/European text: ~3.5 bytes/token
            # - Structured data/JSON: ~2.5-3 bytes/token
            # - Base64/encoded data: ~1.5-2 bytes/token
            # Using 3 as a balanced estimate (the previous value of 2 overestimated token counts by ~2x).
            bytesPerToken = 3  # Balanced estimate for mixed content
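            # Illustrative (hypothetical sizes): a 9,000-byte prompt counts as ~3,000 tokens
            # under this estimate; tokenized as plain English (~4 bytes/token) it would be
            # closer to ~2,250 tokens.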
            promptTokens = promptSize / bytesPerToken
            contextTokens = contextSize / bytesPerToken
            totalTokens = totalSize / bytesPerToken

            logger.debug(f"Request sizes - Prompt: {promptTokens:.0f} tokens ({promptSize} bytes), Context: {contextTokens:.0f} tokens ({contextSize} bytes), Total: {totalTokens:.0f} tokens ({totalSize} bytes)")

            # Step 1: Filter by operation type (MUST match) - keep only models that support it
            operationFiltered = []
            for model in availableModels:
                # Check if the model has the required operation type
                hasOperationType = any(ot.operationType == options.operationType for ot in model.operationTypes)
                if hasOperationType:
                    operationFiltered.append(model)
            logger.debug(f"After operation type filtering: {len(operationFiltered)} models")

            if operationFiltered:
                logger.debug(f"Models with {options.operationType.value}: {[m.name for m in operationFiltered]}")

            # Step 2: Filter by prompt size (MUST be <= 80% of the model's context length)
            # AND by maxInputTokensPerRequest (provider rate limit / TPM).
            # Note: contextLength is in tokens, so we compare tokens with tokens.
            promptFiltered = []
            for model in operationFiltered:
                # Check the provider rate limit first (maxInputTokensPerRequest)
                maxRequestTokens = getattr(model, 'maxInputTokensPerRequest', None)
                if maxRequestTokens and maxRequestTokens > 0 and totalTokens > maxRequestTokens:
                    logger.debug(f"Model {model.name} filtered out: totalTokens={totalTokens:.0f} > maxInputTokensPerRequest={maxRequestTokens} (provider rate limit)")
                    continue

                if model.contextLength == 0:
                    # No context length limit - always pass
                    promptFiltered.append(model)
                else:
                    maxAllowedTokens = model.contextLength * 0.8
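                    # Illustrative: a model with a 128,000-token context admits prompts of up to
                    # ~102,400 tokens at this step (hypothetical numbers).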
                    # Compare prompt tokens (not bytes) with the model's token limit
                    if promptTokens <= maxAllowedTokens:
                        promptFiltered.append(model)
                    else:
                        logger.debug(f"Model {model.name} filtered out: promptSize={promptTokens:.0f} tokens > maxAllowed={maxAllowedTokens:.0f} tokens (80% of {model.contextLength} tokens)")

            logger.debug(f"After prompt size filtering: {len(promptFiltered)} models")

            if not promptFiltered and operationFiltered:
                logger.warning(f"All {len(operationFiltered)} models with {options.operationType.value} were filtered out due to prompt size. Prompt: {promptTokens:.0f} tokens. Available models:")
                for model in operationFiltered:
                    maxAllowed = model.contextLength * 0.8 if model.contextLength > 0 else "unlimited"
                    logger.warning(f"  - {model.name}: contextLength={model.contextLength} tokens, maxAllowed={maxAllowed} tokens")

            # Step 3: Calculate scores for each model (including cooldown penalties)
            scoredModels = []
            for model in promptFiltered:
                score = self._calculateModelScore(model, promptSize, contextSize, totalSize, options)
                penalty = self._getCooldownPenalty(model.name)
                if penalty < 0:
                    logger.debug(f"Model {model.name}: base_score={score:.3f}, cooldown_penalty={penalty:.0f}")
                    score += penalty
                scoredModels.append((model, score))
                logger.debug(f"Model {model.name}: score={score:.3f}")

            # Step 4: Sort by score (descending)
            scoredModels.sort(key=lambda x: x[1], reverse=True)
            sortedModels = [model for model, score in scoredModels]

            logger.debug(f"Final sorted models: {len(sortedModels)} models")
            return sortedModels

        except Exception as e:
            logger.error(f"Error getting failover models: {str(e)}")
            return []

    def _calculateModelScore(self, model: AiModel, promptSize: int, contextSize: int, totalSize: int, options: AiCallOptions) -> float:
        """
        Calculate a score for a model based on how well it fulfills the criteria.
        Operation type rating is the PRIMARY sorting criterion (multiplied by 1000).

        Args:
            model: The model to score
            promptSize: Size of the prompt in bytes
            contextSize: Size of the context in bytes
            totalSize: Total size (prompt + context) in bytes
            options: AI call options

        Returns:
            Score for the model (higher is better)
        """
        score = 0.0

        # 1. PRIMARY: Operation Type Rating (multiplied by 1000 for primary sorting)
        operationTypeRating = self._getOperationTypeRating(model, options.operationType)
        score += operationTypeRating * 1000.0  # Primary sorting criterion
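        # Illustrative: a rating of 7 contributes 7000 points, so operation-type fit dominates
        # the secondary criteria below, each of which adds at most ~1-2 points.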

        # 2. Prompt + Context size rating.
        # totalSize is in bytes, but contextLength is in tokens, so convert using the same
        # ~3 bytes/token estimate as getFailoverModelList before comparing.
        totalTokens = totalSize / 3.0
        if model.contextLength > 0:
            modelMaxTokens = model.contextLength * 0.8  # 80% of model context length
            if totalTokens <= modelMaxTokens:
                # Within limits: rating = (prompt + context tokens) / (80% of context length)
                score += totalTokens / modelMaxTokens
            else:
                # Exceeds limits: rating = (80% of context length) / (prompt + context tokens),
                # which still prefers larger-context models (fewer chunks needed)
                score += modelMaxTokens / totalTokens
        else:
            # No context length limit
            score += 1.0

        # 3. Processing Mode rating
        if hasattr(options, 'processingMode') and options.processingMode:
            score += self._getProcessingModeRating(model.processingMode, options.processingMode)
        else:
            score += 1.0  # No preference

        # 4. Priority rating
        if hasattr(options, 'priority') and options.priority:
            score += self._getPriorityRating(model, options.priority)
        else:
            score += 1.0  # No preference

        return score

    def _getOperationTypeRating(self, model: AiModel, operationType: OperationTypeEnum) -> float:
        """
        Get the operation type rating for a model.

        Args:
            model: The model to check
            operationType: The operation type to get the rating for

        Returns:
            Rating (1-10), or 0 if the model doesn't support this operation type
        """
        for ot_rating in model.operationTypes:
            if ot_rating.operationType == operationType:
                return float(ot_rating.rating)
        return 0.0  # Model doesn't support this operation type

    def _getProcessingModeRating(self, modelMode: ProcessingModeEnum, requestedMode: ProcessingModeEnum) -> float:
        """Get processing mode rating based on compatibility."""
        if modelMode == requestedMode:
            return 1.0

        # Compatibility matrix
        if requestedMode == ProcessingModeEnum.BASIC:
            if modelMode == ProcessingModeEnum.ADVANCED:
                return 0.5
            elif modelMode == ProcessingModeEnum.DETAILED:
                return 0.2

        elif requestedMode == ProcessingModeEnum.ADVANCED:
            if modelMode == ProcessingModeEnum.BASIC:
                return 0.2
            elif modelMode == ProcessingModeEnum.DETAILED:
                return 0.5

        elif requestedMode == ProcessingModeEnum.DETAILED:
            if modelMode == ProcessingModeEnum.BASIC:
                return 0.2
            elif modelMode == ProcessingModeEnum.ADVANCED:
                return 0.5

        return 0.0  # No compatibility

    def _getPriorityRating(self, model: AiModel, requestedPriority: PriorityEnum) -> float:
        """Get priority rating based on model capabilities."""
        if requestedPriority == PriorityEnum.BALANCED:
            return 1.0

        elif requestedPriority == PriorityEnum.SPEED:
            return model.speedRating / 10.0

        elif requestedPriority == PriorityEnum.QUALITY:
            return model.qualityRating / 10.0

        elif requestedPriority == PriorityEnum.COST:
            # Cost priority weights: cost 1.0, speed 0.5, quality 0.2.
            # Lower cost is better, so the cost rating is inverted.
            costRating = 1.0 - (model.costPer1kTokensInput / 0.1)  # Normalize to 0-1; >= $0.10/1k scores 0
            costRating = max(0, costRating)  # Ensure non-negative

            speedRating = model.speedRating / 10.0 * 0.5
            qualityRating = model.qualityRating / 10.0 * 0.2
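            # Illustrative (hypothetical model): $0.03/1k input -> 0.7, speedRating 8 -> 0.4,
            # qualityRating 6 -> 0.12, combined rating ~1.22.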

            return costRating + speedRating + qualityRating

        return 1.0  # Default

    def _getSizeRating(self, model: AiModel, totalSize: int) -> float:
        """Get size rating for a model based on total input size."""
        if model.contextLength > 0:
            modelMaxSize = model.contextLength * 0.8  # 80% of model context length
            if totalSize <= modelMaxSize:
                # Within limits: rating = (prompt+contextsize) / (80% modelsize)
                return totalSize / modelMaxSize
            else:
                # Exceeds limits: rating = modelsize / (prompt+contextsize) (ensures minimum chunks)
                return modelMaxSize / totalSize
        else:
            # No context length limit
            return 1.0

    def _logModelDetails(self, model: AiModel):
        """Log detailed information about a model."""
        logger.info(f"Model: {model.name}")
        logger.info(f"  Display Name: {model.displayName}")
        logger.info(f"  Connector: {model.connectorType}")
        logger.info(f"  Context Length: {model.contextLength}")
        logger.info(f"  Max Tokens: {model.maxTokens}")
        logger.info(f"  Quality Rating: {model.qualityRating}/10")
        logger.info(f"  Speed Rating: {model.speedRating}/10")
        logger.info(f"  Cost: ${model.costPer1kTokensInput:.4f}/1k tokens")
        logger.info(f"  Priority: {model.priority}")
        logger.info(f"  Processing Mode: {model.processingMode}")
        operationTypesStr = ', '.join([f"{ot.operationType.value}({ot.rating})" for ot in model.operationTypes])
        logger.info(f"  Operation Types: {operationTypesStr}")


# Global model selector instance
modelSelector = ModelSelector()
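
# Illustrative usage (sketch only - constructor signatures and enum members live in
# modules.datamodels.datamodelAi and may differ from this hypothetical shape):
#
#     options = AiCallOptions(operationType=someOperationType,  # an OperationTypeEnum member
#                             priority=PriorityEnum.COST)
#     model = modelSelector.selectModel(prompt, context, options, availableModels)
#     if model is not None:
#         try:
#             ...  # call the provider with the selected model
#         except Exception:
#             # Deprioritize the failed model for _COOLDOWN_DURATION seconds and try the next one
#             modelSelector.reportFailure(model.name)
#             fallbacks = modelSelector.getFailoverModelList(prompt, context, options, availableModels)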