# gateway/modules/aicore/aicoreModelSelector.py
"""
Dynamic model selector using configurable rules and scoring.
"""
import logging
from typing import Any, Dict, List, Optional

from modules.datamodels.datamodelAi import (
    AiModel,
    AiCallOptions,
    OperationType,
    Priority,
    ProcessingMode,
    ModelTags,
)
from modules.aicore.aicoreModelSelectionConfig import model_selection_config

logger = logging.getLogger(__name__)


class ModelSelector:
    """Dynamic model selector using configurable rules."""

    def __init__(self):
        self.config = model_selection_config

    def selectModel(self,
                    prompt: str,
                    context: str,
                    options: AiCallOptions,
                    available_models: List[AiModel]) -> Optional[AiModel]:
        """
        Select the best model based on configurable rules and scoring.

        Args:
            prompt: User prompt
            context: Context data
            options: AI call options
            available_models: List of available models to choose from

        Returns:
            Selected model or None if no suitable model found
        """
        if not available_models:
            logger.warning("No models available for selection")
            return None

        logger.info(f"Selecting model for operation: {options.operationType}, priority: {options.priority}")

        # Calculate input size
        input_size = len(prompt.encode("utf-8")) + len(context.encode("utf-8"))

        # Get applicable rules
        rules = self.config.getRulesForOperation(options.operationType)
        logger.debug(f"Found {len(rules)} applicable rules for {options.operationType}")

        # Score each model
        scored_models = []
        for model in available_models:
            if not model.isAvailable:
                continue
            score = self._calculateModelScore(model, input_size, options, rules)
            if score > 0:  # Only consider models with positive scores
                scored_models.append((model, score))
                logger.debug(f"Model {model.name}: score={score:.2f}")

        if not scored_models:
            logger.warning("No models passed the selection criteria, trying fallback criteria")
            # Try fallback criteria
            fallback_criteria = self.getFallbackCriteria(options.operationType)
            return self._selectWithFallbackCriteria(available_models, fallback_criteria, input_size, options)

        # Sort by score (highest first)
        scored_models.sort(key=lambda x: x[1], reverse=True)
        selected_model, selected_score = scored_models[0]
        logger.info(f"Selected model: {selected_model.name} (score: {selected_score:.2f})")

        # Log selection details
        self._logSelectionDetails(selected_model, input_size, options)
        return selected_model

    def _calculateModelScore(self,
                             model: AiModel,
                             input_size: int,
                             options: AiCallOptions,
                             rules: List[Any]) -> float:
        """Calculate score for a model based on rules and criteria."""
        score = 0.0

        # Check basic requirements
        if not self._meetsBasicRequirements(model, input_size, options):
            return 0.0

        # Apply rules, each weighted by its configured weight
        for rule in rules:
            rule_score = self._applyRule(model, input_size, options, rule)
            score += rule_score * rule.weight

        # Apply priority-based scoring
        priority_score = self._applyPriorityScoring(model, options)
        score += priority_score

        # Apply processing mode scoring
        mode_score = self._applyProcessingModeScoring(model, options)
        score += mode_score

        # Apply cost constraints
        if not self._meetsCostConstraints(model, input_size, options):
            score *= 0.1  # Heavily penalize but don't eliminate

        return max(0.0, score)
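
    # Worked example for _calculateModelScore (illustrative numbers, not real
    # config values): rule contributions summing to 1.55 under a rule weight of
    # 2.0, plus a 0.8 priority score and a 0.2 processing-mode score, give
    # 1.55 * 2.0 + 0.8 + 0.2 = 4.1; failing the cost constraint would scale
    # that to 0.41.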

    def _meetsBasicRequirements(self, model: AiModel, input_size: int, options: AiCallOptions) -> bool:
        """Check if model meets basic requirements."""
        # Context length check: only allow the input to fill 80% of the
        # model's context window
        if model.contextLength > 0 and input_size > model.contextLength * 0.8:
            logger.debug(f"Model {model.name} rejected: input too large ({input_size} > {model.contextLength * 0.8})")
            return False

        # Required tags check
        if options.requiredTags:
            if not all(tag in model.tags for tag in options.requiredTags):
                logger.debug(f"Model {model.name} rejected: missing required tags")
                return False

        # Capabilities check
        if options.modelCapabilities:
            if not all(cap in model.capabilities for cap in options.modelCapabilities):
                logger.debug(f"Model {model.name} rejected: missing required capabilities")
                return False

        # Avoid-tags check: reject if any applicable rule lists one of the
        # model's tags as a tag to avoid
        for rule in self.config.getRulesForOperation(options.operationType):
            if any(tag in model.tags for tag in rule.avoid_tags):
                logger.debug(f"Model {model.name} rejected: has avoid tags")
                return False

        return True
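
    # Example for _meetsBasicRequirements (illustrative): a 100000-byte input
    # rejects any model whose contextLength is below 125000, since only 80% of
    # the context window is allowed for input (100000 / 0.8 = 125000).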

    def _applyRule(self, model: AiModel, input_size: int, options: AiCallOptions, rule) -> float:
        """Apply a specific rule to calculate its score contribution."""
        score = 0.0

        # Required tags match
        if all(tag in model.tags for tag in rule.required_tags):
            score += 1.0

        # Preferred tags: partial credit proportional to how many match
        preferred_matches = sum(1 for tag in rule.preferred_tags if tag in model.tags)
        if rule.preferred_tags:
            score += (preferred_matches / len(rule.preferred_tags)) * 0.5

        # Quality rating check
        if rule.min_quality_rating and model.qualityRating >= rule.min_quality_rating:
            score += 0.3

        # Context length check
        if rule.min_context_length and model.contextLength >= rule.min_context_length:
            score += 0.2

        return score
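
    # Worked example for _applyRule (illustrative): matching all required_tags
    # (+1.0), 1 of 2 preferred_tags (+0.25), and a satisfied min_quality_rating
    # (+0.3) yields 1.55, before the rule's weight is applied in
    # _calculateModelScore.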

    def _applyPriorityScoring(self, model: AiModel, options: AiCallOptions) -> float:
        """Apply priority-based scoring."""
        if options.priority == Priority.SPEED:
            return model.speedRating * 0.1
        elif options.priority == Priority.QUALITY:
            return model.qualityRating * 0.1
        elif options.priority == Priority.COST:
            # Lower cost = higher score; note the score bottoms out at 0 once a
            # model costs $0.001 per 1k tokens or more, so this term only
            # differentiates very cheap models
            cost_score = max(0, 1.0 - (model.costPer1kTokens * 1000))
            return cost_score * 0.1
        else:  # BALANCED
            return (model.qualityRating + model.speedRating) * 0.05
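
    # Worked example for _applyPriorityScoring (illustrative ratings): with
    # Priority.QUALITY, a model rated 8/10 contributes 8 * 0.1 = 0.8; with
    # Priority.BALANCED, quality 8 and speed 6 contribute (8 + 6) * 0.05 = 0.7.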

    def _applyProcessingModeScoring(self, model: AiModel, options: AiCallOptions) -> float:
        """Apply processing mode scoring."""
        if options.processingMode == ProcessingMode.DETAILED:
            if ModelTags.HIGH_QUALITY in model.tags:
                return 0.2
        elif options.processingMode == ProcessingMode.BASIC:
            if ModelTags.FAST in model.tags:
                return 0.2
        return 0.0

    def _meetsCostConstraints(self, model: AiModel, input_size: int, options: AiCallOptions) -> bool:
        """Check if model meets cost constraints."""
        if options.maxCost is None:
            return True

        # Estimate cost using a rough heuristic of ~4 bytes per token
        estimated_tokens = input_size / 4
        estimated_cost = (estimated_tokens / 1000) * model.costPer1kTokens
        return estimated_cost <= options.maxCost
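
    # Worked example for _meetsCostConstraints (illustrative numbers): an
    # 8000-byte input is estimated at ~2000 tokens; at $0.002 per 1k tokens the
    # estimated cost is (2000 / 1000) * 0.002 = $0.004, so maxCost=0.01 passes
    # and maxCost=0.003 fails.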

    def _logSelectionDetails(self, model: AiModel, input_size: int, options: AiCallOptions):
        """Log detailed selection information."""
        logger.info("Model Selection Details:")
        logger.info(f" Selected: {model.displayName} ({model.name})")
        logger.info(f" Connector: {model.connectorType}")
        logger.info(f" Operation: {options.operationType}")
        logger.info(f" Priority: {options.priority}")
        logger.info(f" Processing Mode: {options.processingMode}")
        logger.info(f" Input Size: {input_size} bytes")
        logger.info(f" Context Length: {model.contextLength}")
        logger.info(f" Max Tokens: {model.maxTokens}")
        logger.info(f" Quality Rating: {model.qualityRating}/10")
        logger.info(f" Speed Rating: {model.speedRating}/10")
        logger.info(f" Cost: ${model.costPer1kTokens:.4f}/1k tokens")
        logger.info(f" Capabilities: {', '.join(model.capabilities)}")
        logger.info(f" Tags: {', '.join(model.tags)}")

    def getFallbackCriteria(self, operation_type: str) -> Dict[str, Any]:
        """Get fallback selection criteria for an operation type."""
        return self.config.getFallbackCriteria(operation_type)

    def _selectWithFallbackCriteria(self,
                                    available_models: List[AiModel],
                                    fallback_criteria: Dict[str, Any],
                                    input_size: int,
                                    options: AiCallOptions) -> Optional[AiModel]:
        """Select model using fallback criteria when normal selection fails."""
        logger.info("Using fallback criteria for model selection")

        # Filter models by fallback criteria
        candidates = []
        for model in available_models:
            if not model.isAvailable:
                continue

            # Check required tags
            if fallback_criteria.get("required_tags"):
                if not all(tag in model.tags for tag in fallback_criteria["required_tags"]):
                    continue

            # Check quality rating
            if fallback_criteria.get("min_quality_rating"):
                if model.qualityRating < fallback_criteria["min_quality_rating"]:
                    continue

            # Check cost
            if fallback_criteria.get("max_cost_per_1k"):
                if model.costPer1kTokens > fallback_criteria["max_cost_per_1k"]:
                    continue

            # Check context length (same 80% headroom as the primary path)
            if model.contextLength > 0 and input_size > model.contextLength * 0.8:
                continue

            candidates.append(model)

        if not candidates:
            logger.error("No models available even with fallback criteria")
            return None

        # Rank candidates by the priority order from the fallback criteria
        priority_order = fallback_criteria.get("priority_order", ["quality", "speed", "cost"])

        def get_priority_score(model: AiModel) -> float:
            score = 0.0
            for i, priority in enumerate(priority_order):
                weight = len(priority_order) - i  # Higher weight for earlier priorities
                if priority == "quality":
                    score += model.qualityRating * weight
                elif priority == "speed":
                    score += model.speedRating * weight
                elif priority == "cost":
                    # Lower cost = higher score; unlike _applyPriorityScoring,
                    # this term is not clamped at zero, so costly models can
                    # contribute negatively
                    score += (1.0 - model.costPer1kTokens * 1000) * weight
            return score

        candidates.sort(key=get_priority_score, reverse=True)
        selected_model = candidates[0]
        logger.info(f"Fallback selection: {selected_model.name} (score: {get_priority_score(selected_model):.2f})")
        return selected_model
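
    # Worked example for the fallback ranking (illustrative): with the default
    # priority_order ["quality", "speed", "cost"], the weights are 3, 2, and 1,
    # so a model with quality 8, speed 6, and $0.0005/1k tokens scores
    # 8 * 3 + 6 * 2 + (1.0 - 0.5) * 1 = 36.5.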


# Global selector instance
model_selector = ModelSelector()
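
# Usage sketch (illustrative; assumes AiCallOptions accepts these keyword
# arguments and that OperationType defines a suitable member - neither is
# verified here):
#
#     options = AiCallOptions(
#         operationType=OperationType.SUMMARIZE,  # hypothetical member
#         priority=Priority.BALANCED,
#         processingMode=ProcessingMode.BASIC,
#     )
#     selected = model_selector.selectModel(prompt, context, options, available_models)
#     if selected is None:
#         ...  # no model met even the fallback criteria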