279 lines
12 KiB
Python
279 lines
12 KiB
Python
"""
|
|
Dynamic model selector using configurable rules and scoring.
|
|
"""
|
|
|
|
import logging
|
|
from typing import List, Optional, Dict, Any, Tuple
|
|
from modules.datamodels.datamodelAi import AiModel, AiCallOptions, OperationType, Priority, ProcessingMode, ModelTags
|
|
from modules.aicore.aicoreModelSelectionConfig import model_selection_config
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModelSelector:
    """Dynamic model selector that scores candidate models against configurable rules.

    Scoring combines per-operation rules from ``model_selection_config`` with
    priority, processing-mode, and cost heuristics.  Models failing hard
    requirements score 0 and are excluded; if nothing passes, a relaxed
    fallback path (``_selectWithFallbackCriteria``) is tried before giving up.
    """

    def __init__(self):
        # Shared module-level configuration (rules + fallback criteria per operation).
        self.config = model_selection_config

    def selectModel(self,
                    prompt: str,
                    context: str,
                    options: AiCallOptions,
                    available_models: List[AiModel]) -> Optional[AiModel]:
        """
        Select the best model based on configurable rules and scoring.

        Args:
            prompt: User prompt
            context: Context data
            options: AI call options
            available_models: List of available models to choose from

        Returns:
            Selected model or None if no suitable model found
        """
        if not available_models:
            logger.warning("No models available for selection")
            return None

        logger.info("Selecting model for operation: %s, priority: %s",
                    options.operationType, options.priority)

        # Input size in UTF-8 bytes; used as a rough proxy for token count below.
        input_size = len(prompt.encode("utf-8")) + len(context.encode("utf-8"))

        # Fetch the applicable rules ONCE and pass them down — previously
        # _meetsBasicRequirements re-queried the config for every candidate model.
        rules = self.config.getRulesForOperation(options.operationType)
        logger.debug("Found %d applicable rules for %s", len(rules), options.operationType)

        # Score each available model; non-positive scores are discarded.
        scored_models: List[Tuple[AiModel, float]] = []
        for model in available_models:
            if not model.isAvailable:
                continue

            score = self._calculateModelScore(model, input_size, options, rules)
            if score > 0:  # Only consider models with positive scores
                scored_models.append((model, score))
                logger.debug("Model %s: score=%.2f", model.name, score)

        if not scored_models:
            logger.warning("No models passed the selection criteria, trying fallback criteria")
            # Try fallback criteria
            fallback_criteria = self.getFallbackCriteria(options.operationType)
            return self._selectWithFallbackCriteria(available_models, fallback_criteria, input_size, options)

        # Sort by score (highest first) and take the winner.
        scored_models.sort(key=lambda pair: pair[1], reverse=True)
        selected_model, selected_score = scored_models[0]

        logger.info("Selected model: %s (score: %.2f)", selected_model.name, selected_score)

        # Log selection details
        self._logSelectionDetails(selected_model, input_size, options)

        return selected_model

    def _calculateModelScore(self,
                             model: AiModel,
                             input_size: int,
                             options: AiCallOptions,
                             rules: List) -> float:
        """Calculate a composite score for ``model``.

        Args:
            model: Candidate model.
            input_size: Combined prompt + context size in bytes.
            options: AI call options (priority, processing mode, cost cap, ...).
            rules: Pre-fetched rules for the operation type.

        Returns:
            Non-negative score; 0.0 means the model must not be used.
        """
        # Hard requirements first: a model that fails them scores 0 outright.
        if not self._meetsBasicRequirements(model, input_size, options, rules):
            return 0.0

        score = 0.0

        # Weighted rule contributions.
        for rule in rules:
            rule_score = self._applyRule(model, input_size, options, rule)
            score += rule_score * rule.weight

        # Soft preferences: caller priority and processing mode.
        score += self._applyPriorityScoring(model, options)
        score += self._applyProcessingModeScoring(model, options)

        # Over-budget models are heavily penalized but not eliminated, so one
        # can still win when nothing affordable qualifies.
        if not self._meetsCostConstraints(model, input_size, options):
            score *= 0.1

        return max(0.0, score)

    def _meetsBasicRequirements(self, model: AiModel, input_size: int,
                                options: AiCallOptions,
                                rules: Optional[List] = None) -> bool:
        """Check hard requirements; any failure disqualifies the model.

        Args:
            rules: Pre-fetched rules for the operation type.  Fetched from the
                config when omitted (parameter is optional for backward
                compatibility with older callers).
        """
        # Context length check: reject inputs above 80% of the context window,
        # leaving head-room for the model's own output.
        if model.contextLength > 0 and input_size > model.contextLength * 0.8:
            logger.debug("Model %s rejected: input too large (%s > %s)",
                         model.name, input_size, model.contextLength * 0.8)
            return False

        # Required tags check
        if options.requiredTags:
            if not all(tag in model.tags for tag in options.requiredTags):
                logger.debug("Model %s rejected: missing required tags", model.name)
                return False

        # Capabilities check
        if options.modelCapabilities:
            if not all(cap in model.capabilities for cap in options.modelCapabilities):
                logger.debug("Model %s rejected: missing required capabilities", model.name)
                return False

        # Avoid-tags check: any rule's avoid_tags present on the model rejects it.
        if rules is None:
            rules = self.config.getRulesForOperation(options.operationType)
        for rule in rules:
            if any(tag in model.tags for tag in rule.avoid_tags):
                logger.debug("Model %s rejected: has avoid tags", model.name)
                return False

        return True

    def _applyRule(self, model: AiModel, input_size: int, options: AiCallOptions, rule) -> float:
        """Score one rule's contribution (before the rule weight is applied).

        Contributions: +1.0 all required tags present, up to +0.5 for preferred
        tag coverage, +0.3 quality threshold met, +0.2 context-length threshold met.
        """
        score = 0.0

        # Required tags match (vacuously true when the rule has none).
        if all(tag in model.tags for tag in rule.required_tags):
            score += 1.0

        # Preferred tags match — scaled by the fraction of preferred tags present.
        preferred_matches = sum(1 for tag in rule.preferred_tags if tag in model.tags)
        if rule.preferred_tags:
            score += (preferred_matches / len(rule.preferred_tags)) * 0.5

        # Quality rating check
        if rule.min_quality_rating and model.qualityRating >= rule.min_quality_rating:
            score += 0.3

        # Context length check
        if rule.min_context_length and model.contextLength >= rule.min_context_length:
            score += 0.2

        return score

    def _applyPriorityScoring(self, model: AiModel, options: AiCallOptions) -> float:
        """Apply priority-based scoring (small additive bonus, max ~1.0)."""
        if options.priority == Priority.SPEED:
            return model.speedRating * 0.1
        elif options.priority == Priority.QUALITY:
            return model.qualityRating * 0.1
        elif options.priority == Priority.COST:
            # Lower cost = higher score.  The *1000 normalization assumes
            # costPer1kTokens is on a ~$0.001 scale — TODO confirm; models
            # costing >= $0.001/1k all collapse to 0 here.
            cost_score = max(0, 1.0 - (model.costPer1kTokens * 1000))
            return cost_score * 0.1
        else:  # BALANCED
            return (model.qualityRating + model.speedRating) * 0.05

    def _applyProcessingModeScoring(self, model: AiModel, options: AiCallOptions) -> float:
        """Apply processing-mode scoring: +0.2 when the model's tags suit the mode."""
        if options.processingMode == ProcessingMode.DETAILED:
            if ModelTags.HIGH_QUALITY in model.tags:
                return 0.2
        elif options.processingMode == ProcessingMode.BASIC:
            if ModelTags.FAST in model.tags:
                return 0.2

        return 0.0

    def _meetsCostConstraints(self, model: AiModel, input_size: int, options: AiCallOptions) -> bool:
        """Check if the estimated call cost fits within ``options.maxCost``.

        Returns True when no cost cap is set.
        """
        if options.maxCost is None:
            return True

        # Estimate cost using a rough ~4-bytes-per-token heuristic.
        estimated_tokens = input_size / 4
        estimated_cost = (estimated_tokens / 1000) * model.costPer1kTokens

        return estimated_cost <= options.maxCost

    def _logSelectionDetails(self, model: AiModel, input_size: int, options: AiCallOptions):
        """Log detailed selection information at INFO level (diagnostics only)."""
        logger.info("Model Selection Details:")
        logger.info("  Selected: %s (%s)", model.displayName, model.name)
        logger.info("  Connector: %s", model.connectorType)
        logger.info("  Operation: %s", options.operationType)
        logger.info("  Priority: %s", options.priority)
        logger.info("  Processing Mode: %s", options.processingMode)
        logger.info("  Input Size: %s bytes", input_size)
        logger.info("  Context Length: %s", model.contextLength)
        logger.info("  Max Tokens: %s", model.maxTokens)
        logger.info("  Quality Rating: %s/10", model.qualityRating)
        logger.info("  Speed Rating: %s/10", model.speedRating)
        logger.info("  Cost: $%.4f/1k tokens", model.costPer1kTokens)
        logger.info("  Capabilities: %s", ", ".join(model.capabilities))
        logger.info("  Tags: %s", ", ".join(model.tags))

    def getFallbackCriteria(self, operation_type: str) -> Dict[str, Any]:
        """Get fallback selection criteria for an operation type."""
        return self.config.getFallbackCriteria(operation_type)

    def _selectWithFallbackCriteria(self,
                                    available_models: List[AiModel],
                                    fallback_criteria: Dict[str, Any],
                                    input_size: int,
                                    options: AiCallOptions) -> Optional[AiModel]:
        """Select a model using relaxed fallback criteria when normal selection fails.

        Filters by the fallback criteria's hard limits, then ranks the survivors
        by the criteria's ``priority_order`` (earlier entries weigh more).
        Returns None when no model survives even the fallback filter.
        """
        logger.info("Using fallback criteria for model selection")

        # Filter models by fallback criteria.
        candidates = []
        for model in available_models:
            if not model.isAvailable:
                continue

            # Check required tags
            if fallback_criteria.get("required_tags"):
                if not all(tag in model.tags for tag in fallback_criteria["required_tags"]):
                    continue

            # Check quality rating
            if fallback_criteria.get("min_quality_rating"):
                if model.qualityRating < fallback_criteria["min_quality_rating"]:
                    continue

            # Check cost
            if fallback_criteria.get("max_cost_per_1k"):
                if model.costPer1kTokens > fallback_criteria["max_cost_per_1k"]:
                    continue

            # Check context length (same 80% head-room rule as the normal path).
            if model.contextLength > 0 and input_size > model.contextLength * 0.8:
                continue

            candidates.append(model)

        if not candidates:
            logger.error("No models available even with fallback criteria")
            return None

        # Rank by priority order from the fallback criteria.
        priority_order = fallback_criteria.get("priority_order", ["quality", "speed", "cost"])

        def get_priority_score(model: AiModel) -> float:
            score = 0.0
            for i, priority in enumerate(priority_order):
                weight = len(priority_order) - i  # Higher weight for earlier priorities
                if priority == "quality":
                    score += model.qualityRating * weight
                elif priority == "speed":
                    score += model.speedRating * weight
                elif priority == "cost":
                    # Lower cost = higher score.  NOTE(review): unlike
                    # _applyPriorityScoring this is NOT clamped at 0, so
                    # expensive models can contribute negatively — confirm
                    # whether that asymmetry is intentional.
                    score += (1.0 - model.costPer1kTokens * 1000) * weight
            return score

        candidates.sort(key=get_priority_score, reverse=True)
        selected_model = candidates[0]

        logger.info("Fallback selection: %s (score: %.2f)",
                    selected_model.name, get_priority_score(selected_model))
        return selected_model
|
|
|
|
|
|
# Global selector instance — module-level singleton shared by importers.
model_selector = ModelSelector()
|