180 lines
7.4 KiB
Python
180 lines
7.4 KiB
Python
"""
|
|
Configuration for dynamic model selection rules.
|
|
This makes model selection configurable rather than hardcoded.
|
|
"""
|
|
|
|
from typing import Dict, List, Any, Optional
|
|
from dataclasses import dataclass
|
|
from modules.datamodels.datamodelAi import OperationType, Priority, ProcessingMode, ModelTags
|
|
|
|
|
|
@dataclass
class SelectionRule:
    """One weighted rule that steers model selection.

    A rule targets a set of operation types and describes candidate models
    through three tag lists (required / preferred / to avoid), optionally
    tightened by numeric thresholds. All three thresholds default to ``None``
    (not set).

    NOTE(review): the default rules fill ``operation_types`` and the tag lists
    with ``OperationType`` / ``ModelTags`` members while the annotations say
    ``List[str]`` -- presumably those members are str-compatible; confirm.
    """

    # Identifier of the rule; the config helpers look rules up by this value.
    name: str
    # Human-readable description of the situation the rule is meant for.
    condition: str
    # Scoring weight -- a larger number marks a more important rule.
    weight: float
    # Operation types for which this rule is relevant.
    operation_types: List[str]
    # Tags a candidate model is required to carry under this rule.
    required_tags: List[str]
    # Tags a candidate model should preferably carry.
    preferred_tags: List[str]
    # Tags a candidate model should preferably not carry.
    avoid_tags: List[str]

    # ---- Optional thresholds (None = not set) ----
    # Lowest acceptable quality rating.
    min_quality_rating: Optional[int] = None
    # Highest acceptable cost; the default rules annotate 0.01 as
    # "$0.01 per 1k tokens", so presumably a per-1k-token price -- confirm.
    max_cost: Optional[float] = None
    # Smallest required context length; the default rules annotate 100000
    # as "100k tokens".
    min_context_length: Optional[int] = None
|
|
|
|
|
|
class ModelSelectionConfig:
    """Configuration for model selection rules.

    Holds two pieces of state, both rebuilt from built-in defaults on
    construction:

    * ``rules`` -- a list of ``SelectionRule`` objects, filtered per
      operation type by :meth:`getRulesForOperation`.
    * ``fallback_models`` -- a dict mapping an operation type to a plain dict
      of fallback selection criteria, read via :meth:`getFallbackCriteria`.
    """

    def __init__(self) -> None:
        """Create a configuration populated with the default rules and fallbacks."""
        self.rules = self._loadDefaultRules()
        self.fallback_models = self._loadFallbackModels()

    def _loadDefaultRules(self) -> "List[SelectionRule]":
        """Return a fresh list of the default selection rules.

        A new list (and new rule objects) is built on every call, so mutating
        one configuration's rules never affects another instance.
        """
        return [
            # High quality for planning and analysis
            SelectionRule(
                name="high_quality_analysis",
                condition="Planning or analysis operations requiring high quality",
                weight=10.0,
                operation_types=[OperationType.GENERATE_PLAN, OperationType.ANALYSE_CONTENT],
                required_tags=[ModelTags.TEXT, ModelTags.REASONING, ModelTags.ANALYSIS],
                preferred_tags=[ModelTags.HIGH_QUALITY],
                avoid_tags=[ModelTags.FAST],
                min_quality_rating=8,
            ),
            # Fast processing for basic operations
            SelectionRule(
                name="fast_basic_processing",
                condition="Basic operations requiring speed",
                weight=8.0,
                operation_types=[OperationType.GENERAL],
                required_tags=[ModelTags.TEXT, ModelTags.CHAT],
                preferred_tags=[ModelTags.FAST],
                avoid_tags=[],
                min_quality_rating=5,
            ),
            # Cost-effective for high-volume operations
            SelectionRule(
                name="cost_effective_processing",
                condition="High-volume operations where cost matters",
                weight=7.0,
                operation_types=[OperationType.GENERAL, OperationType.GENERATE_CONTENT],
                required_tags=[ModelTags.TEXT],
                preferred_tags=[ModelTags.COST_EFFECTIVE],
                avoid_tags=[],
                max_cost=0.01,  # $0.01 per 1k tokens
            ),
            # Image analysis specific
            SelectionRule(
                name="image_analysis",
                condition="Image analysis operations",
                weight=10.0,
                operation_types=[OperationType.IMAGE_ANALYSIS],
                required_tags=[ModelTags.IMAGE, ModelTags.VISION, ModelTags.MULTIMODAL],
                preferred_tags=[ModelTags.HIGH_QUALITY],
                avoid_tags=[],
                min_quality_rating=8,
            ),
            # Web research specific
            SelectionRule(
                name="web_research",
                condition="Web research operations",
                weight=9.0,
                operation_types=[OperationType.WEB_RESEARCH],
                required_tags=[ModelTags.TEXT, ModelTags.ANALYSIS],
                preferred_tags=[ModelTags.WEB, ModelTags.SEARCH],
                avoid_tags=[],
                min_quality_rating=7,
            ),
            # Large context requirements
            SelectionRule(
                name="large_context",
                condition="Operations requiring large context",
                weight=8.0,
                operation_types=[OperationType.GENERAL, OperationType.ANALYSE_CONTENT],
                required_tags=[ModelTags.TEXT],
                preferred_tags=[],
                avoid_tags=[],
                min_context_length=100000,  # 100k tokens
            ),
        ]

    def _loadFallbackModels(self) -> Dict[str, Dict[str, Any]]:
        """Return fresh fallback selection criteria keyed by operation type.

        Every entry carries the same keys (``priority_order``, ``required_tags``,
        ``preferred_tags``, ``min_quality_rating``, ``max_cost_per_1k``) so
        consumers can index them uniformly; ``preferred_tags`` is an empty list
        where no preference applies.

        Operation types without an entry here (e.g. ``GENERATE_CONTENT``, which
        the default rules reference) resolve to the ``GENERAL`` entry in
        :meth:`getFallbackCriteria`.
        """
        return {
            OperationType.GENERAL: {
                "priority_order": ["speed", "quality", "cost"],
                "required_tags": [ModelTags.TEXT, ModelTags.CHAT],
                "preferred_tags": [],
                "min_quality_rating": 5,
                "max_cost_per_1k": 0.01,
            },
            OperationType.IMAGE_ANALYSIS: {
                "priority_order": ["quality", "speed"],
                "required_tags": [ModelTags.IMAGE, ModelTags.VISION, ModelTags.MULTIMODAL],
                "preferred_tags": [],
                "min_quality_rating": 8,
                "max_cost_per_1k": 0.1,
            },
            OperationType.IMAGE_GENERATION: {
                "priority_order": ["quality", "speed"],
                "required_tags": [ModelTags.IMAGE_GENERATION, ModelTags.ART, ModelTags.VISUAL],
                "preferred_tags": [],
                "min_quality_rating": 8,
                "max_cost_per_1k": 0.1,
            },
            OperationType.WEB_RESEARCH: {
                "priority_order": ["quality", "speed", "cost"],
                "required_tags": [ModelTags.TEXT, ModelTags.ANALYSIS],
                "preferred_tags": [ModelTags.WEB, ModelTags.SEARCH],
                "min_quality_rating": 7,
                "max_cost_per_1k": 0.02,
            },
            OperationType.GENERATE_PLAN: {
                "priority_order": ["quality", "speed"],
                "required_tags": [ModelTags.TEXT, ModelTags.REASONING, ModelTags.ANALYSIS],
                "preferred_tags": [ModelTags.HIGH_QUALITY],
                "min_quality_rating": 8,
                "max_cost_per_1k": 0.1,
            },
            OperationType.ANALYSE_CONTENT: {
                "priority_order": ["quality", "speed"],
                "required_tags": [ModelTags.TEXT, ModelTags.ANALYSIS, ModelTags.REASONING],
                "preferred_tags": [ModelTags.HIGH_QUALITY],
                "min_quality_rating": 8,
                "max_cost_per_1k": 0.1,
            },
        }

    def getRulesForOperation(self, operation_type: str) -> "List[SelectionRule]":
        """Return the rules whose ``operation_types`` contain *operation_type*.

        Rules are returned in their stored order; an empty list means no rule
        targets this operation type.
        """
        return [rule for rule in self.rules if operation_type in rule.operation_types]

    def getFallbackCriteria(self, operation_type: str) -> Dict[str, Any]:
        """Return fallback selection criteria for *operation_type*.

        Unknown operation types fall back to the ``GENERAL`` entry. The stored
        dict itself is returned (not a copy), so mutating it changes this
        configuration.

        Raises:
            KeyError: if *operation_type* is unknown and no ``GENERAL`` entry
                exists.
        """
        if operation_type in self.fallback_models:
            return self.fallback_models[operation_type]
        # Resolve the GENERAL default lazily: passing it as a dict.get()
        # default would evaluate (and possibly KeyError) on every call, even
        # for operation types that have their own entry.
        return self.fallback_models[OperationType.GENERAL]

    def addRule(self, rule: "SelectionRule") -> None:
        """Append *rule* to the rule list.

        Names are not checked for uniqueness; with duplicate names,
        :meth:`updateRule` only touches the first match.
        """
        self.rules.append(rule)

    def removeRule(self, rule_name: str) -> bool:
        """Remove every rule named *rule_name*.

        Returns:
            True if at least one rule was removed, False if none matched.
        """
        remaining = [rule for rule in self.rules if rule.name != rule_name]
        removed = len(remaining) != len(self.rules)
        self.rules = remaining
        return removed

    def updateRule(self, rule_name: str, **kwargs) -> bool:
        """Set attributes on the first rule named *rule_name*.

        Keyword names that are not existing attributes of the rule are
        ignored (the ``hasattr`` guard keeps typos from adding new attributes).

        Returns:
            True if a rule with that name was found (even if no attribute was
            changed), False otherwise.
        """
        for rule in self.rules:
            if rule.name == rule_name:
                for key, value in kwargs.items():
                    if hasattr(rule, key):
                        setattr(rule, key, value)
                return True
        return False
|
|
|
|
|
|
# Global configuration instance.
# Built once, when this module is first imported; ModelSelectionConfig()
# loads the default rules and fallback criteria in its constructor. Every
# importer of this name gets the same object, so addRule/removeRule/updateRule
# calls made through it are visible to all of them.
model_selection_config = ModelSelectionConfig()
|