updated price calculations and basis for refactory of dynamic content handling
This commit is contained in:
parent
109e77fd60
commit
6b819cc848
9 changed files with 222 additions and 195 deletions
|
|
@ -172,4 +172,4 @@ class ModelRegistry:
|
|||
|
||||
|
||||
# Global registry instance
|
||||
model_registry = ModelRegistry()
|
||||
modelRegistry = ModelRegistry()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import time
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from modules.aicore.aicoreModelRegistry import model_registry
|
||||
from modules.aicore.aicoreModelRegistry import modelRegistry
|
||||
from modules.aicore.aicoreModelSelector import model_selector
|
||||
from modules.datamodels.datamodelAi import (
|
||||
AiModel,
|
||||
|
|
@ -42,11 +42,11 @@ class AiObjects:
|
|||
logger.info("Auto-discovering AI connectors...")
|
||||
|
||||
# Use the model registry's built-in discovery mechanism
|
||||
discoveredConnectors = model_registry.discoverConnectors()
|
||||
discoveredConnectors = modelRegistry.discoverConnectors()
|
||||
|
||||
# Register each discovered connector
|
||||
for connector in discoveredConnectors:
|
||||
model_registry.registerConnector(connector)
|
||||
modelRegistry.registerConnector(connector)
|
||||
logger.info(f"Registered connector: {connector.getConnectorType()}")
|
||||
|
||||
logger.info(f"Total connectors registered: {len(discoveredConnectors)}")
|
||||
|
|
@ -63,7 +63,7 @@ class AiObjects:
|
|||
def _selectModel(self, prompt: str, context: str, options: AiCallOptions) -> str:
|
||||
"""Select the best model using dynamic model selection system."""
|
||||
# Get available models from the dynamic registry
|
||||
availableModels = model_registry.getAvailableModels()
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
|
||||
if not availableModels:
|
||||
logger.error("No models available in the registry")
|
||||
|
|
@ -109,7 +109,7 @@ class AiObjects:
|
|||
maxTokens = getattr(options, "maxTokens", None)
|
||||
|
||||
# Get fallback models for this operation type
|
||||
availableModels = model_registry.getAvailableModels()
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
fallbackModels = model_selector.getFallbackModels(prompt, context, options, availableModels)
|
||||
|
||||
if not fallbackModels:
|
||||
|
|
@ -188,7 +188,7 @@ class AiObjects:
|
|||
startTime = time.time()
|
||||
|
||||
# Get the connector for this model
|
||||
connector = model_registry.getConnectorForModel(model.name)
|
||||
connector = modelRegistry.getConnectorForModel(model.name)
|
||||
if not connector:
|
||||
raise ValueError(f"No connector found for model {model.name}")
|
||||
|
||||
|
|
@ -221,8 +221,8 @@ class AiObjects:
|
|||
processingTime = endTime - startTime
|
||||
outputBytes = len(content.encode("utf-8"))
|
||||
|
||||
# Calculate price using model's cost information
|
||||
priceUsd = model.costPer1kTokensInput * (inputBytes / 4 / 1000) + model.costPer1kTokensOutput * (outputBytes / 4 / 1000)
|
||||
# Calculate price using model's own price calculation method
|
||||
priceUsd = model.calculatePriceUsd(inputBytes, outputBytes)
|
||||
|
||||
return AiCallResponse(
|
||||
content=content,
|
||||
|
|
@ -244,7 +244,7 @@ class AiObjects:
|
|||
inputBytes = len(prompt.encode("utf-8")) + len(imageData) if isinstance(imageData, bytes) else len(prompt.encode("utf-8")) + len(str(imageData).encode("utf-8"))
|
||||
|
||||
# Get fallback models for image analysis
|
||||
availableModels = model_registry.getAvailableModels()
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
fallbackModels = model_selector.getFallbackModels(prompt, "", options, availableModels)
|
||||
|
||||
if not fallbackModels:
|
||||
|
|
@ -314,8 +314,8 @@ class AiObjects:
|
|||
processingTime = endTime - startTime
|
||||
outputBytes = len(content.encode("utf-8"))
|
||||
|
||||
# Calculate price using model's cost information
|
||||
priceUsd = model.costPer1kTokensInput * (inputBytes / 4 / 1000) + model.costPer1kTokensOutput * (outputBytes / 4 / 1000)
|
||||
# Calculate price using model's own price calculation method
|
||||
priceUsd = model.calculatePriceUsd(inputBytes, outputBytes)
|
||||
|
||||
return AiCallResponse(
|
||||
content=content,
|
||||
|
|
@ -339,13 +339,13 @@ class AiObjects:
|
|||
try:
|
||||
# Select the best model for image generation
|
||||
modelName = self._selectModel(prompt, "", options)
|
||||
selectedModel = model_registry.getModel(modelName)
|
||||
selectedModel = modelRegistry.getModel(modelName)
|
||||
|
||||
if not selectedModel:
|
||||
raise ValueError(f"Selected model {modelName} not found in registry")
|
||||
|
||||
# Get the connector for this model
|
||||
connector = model_registry.getConnectorForModel(modelName)
|
||||
connector = modelRegistry.getConnectorForModel(modelName)
|
||||
if not connector:
|
||||
raise ValueError(f"No connector found for model {modelName}")
|
||||
|
||||
|
|
@ -364,9 +364,8 @@ class AiObjects:
|
|||
processingTime = endTime - startTime
|
||||
outputBytes = len(content.encode("utf-8"))
|
||||
|
||||
# Calculate price using model's cost information
|
||||
estimatedTokens = inputBytes / 4
|
||||
priceUsd = (estimatedTokens / 1000) * selectedModel.costPer1kTokensInput + (outputBytes / 4 / 1000) * selectedModel.costPer1kTokensOutput
|
||||
# Calculate price using model's own price calculation method
|
||||
priceUsd = selectedModel.calculatePriceUsd(inputBytes, outputBytes)
|
||||
|
||||
logger.info(f"✅ Image generation successful with model: {modelName}")
|
||||
return AiCallResponse(
|
||||
|
|
@ -401,7 +400,7 @@ class AiObjects:
|
|||
**kwargs
|
||||
)
|
||||
# Get Tavily connector from registry
|
||||
tavilyConnector = model_registry.getConnectorForModel("tavily_search")
|
||||
tavilyConnector = modelRegistry.getConnectorForModel("tavily_search")
|
||||
if not tavilyConnector:
|
||||
raise ValueError("Tavily connector not available")
|
||||
result = await tavilyConnector.search(request)
|
||||
|
|
@ -440,7 +439,7 @@ class AiObjects:
|
|||
format=format
|
||||
)
|
||||
# Get Tavily connector from registry
|
||||
tavilyConnector = model_registry.getConnectorForModel("tavily_crawl")
|
||||
tavilyConnector = modelRegistry.getConnectorForModel("tavily_crawl")
|
||||
if not tavilyConnector:
|
||||
raise ValueError("Tavily connector not available")
|
||||
result = await tavilyConnector.crawl(request)
|
||||
|
|
@ -792,7 +791,7 @@ Format your response in a clear, professional manner that would be helpful for s
|
|||
startTime = time.time()
|
||||
|
||||
# Use Perplexity for web research with search capabilities
|
||||
perplexity_connector = model_registry.getConnectorForModel("perplexity_callAiWithWebSearch")
|
||||
perplexity_connector = modelRegistry.getConnectorForModel("perplexity_callAiWithWebSearch")
|
||||
if not perplexity_connector:
|
||||
raise ValueError("Perplexity connector not available")
|
||||
response = await perplexity_connector.callAiWithWebSearch(webPrompt)
|
||||
|
|
@ -803,12 +802,8 @@ Format your response in a clear, professional manner that would be helpful for s
|
|||
outputBytes = len(response.encode("utf-8"))
|
||||
|
||||
# Calculate price using Perplexity model pricing
|
||||
perplexity_model = model_registry.getModel("perplexity_callAiWithWebSearch")
|
||||
if perplexity_model:
|
||||
estimated_tokens = inputBytes / 4
|
||||
priceUsd = (estimated_tokens / 1000) * perplexity_model.costPer1kTokensInput + (outputBytes / 4 / 1000) * perplexity_model.costPer1kTokensOutput
|
||||
else:
|
||||
priceUsd = 0.0
|
||||
perplexity_model = modelRegistry.getModel("perplexity_callAiWithWebSearch")
|
||||
priceUsd = perplexity_model.calculatePriceUsd(inputBytes, outputBytes)
|
||||
|
||||
logger.info(f"✅ Web query successful with Perplexity")
|
||||
return AiCallResponse(
|
||||
|
|
@ -835,26 +830,26 @@ Format your response in a clear, professional manner that would be helpful for s
|
|||
# Utility methods
|
||||
async def listAvailableModels(self, connectorType: str = None) -> List[Dict[str, Any]]:
|
||||
"""List available models, optionally filtered by connector type."""
|
||||
models = model_registry.getAvailableModels()
|
||||
models = modelRegistry.getAvailableModels()
|
||||
if connectorType:
|
||||
return [model.dict() for model in models if model.connectorType == connectorType]
|
||||
return [model.dict() for model in models]
|
||||
|
||||
async def getModelInfo(self, modelName: str) -> Dict[str, Any]:
|
||||
"""Get information about a specific model."""
|
||||
model = model_registry.getModel(modelName)
|
||||
model = modelRegistry.getModel(modelName)
|
||||
if not model:
|
||||
raise ValueError(f"Model {modelName} not found")
|
||||
return model.dict()
|
||||
|
||||
async def getModelsByCapability(self, capability: str) -> List[str]:
|
||||
"""Get model names that support a specific capability."""
|
||||
models = model_registry.getModelsByCapability(capability)
|
||||
models = modelRegistry.getModelsByCapability(capability)
|
||||
return [model.name for model in models]
|
||||
|
||||
async def getModelsByTag(self, tag: str) -> List[str]:
|
||||
"""Get model names that have a specific tag."""
|
||||
models = model_registry.getModelsByTag(tag)
|
||||
models = modelRegistry.getModelsByTag(tag)
|
||||
return [model.name for model in models]
|
||||
|
||||
async def selectRelevantWebsites(self, websites: List[str], userQuestion: str) -> Tuple[List[str], str]:
|
||||
|
|
|
|||
|
|
@ -547,81 +547,3 @@ CRITICAL REQUIREMENTS:
|
|||
except Exception as e:
|
||||
logger.error(f"Error in AI image generation: {str(e)}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
|
||||
def _getModelCapabilitiesForContent(self, prompt: str, documents: Optional[List[ChatDocument]], options: AiCallOptions) -> Dict[str, int]:
|
||||
"""
|
||||
Get model capabilities for content processing, including appropriate size limits for chunking.
|
||||
"""
|
||||
# Estimate total content size
|
||||
prompt_size = len(prompt.encode('utf-8'))
|
||||
document_size = 0
|
||||
if documents:
|
||||
# Rough estimate of document content size
|
||||
for doc in documents:
|
||||
document_size += doc.fileSize or 0
|
||||
|
||||
total_size = prompt_size + document_size
|
||||
|
||||
# Use AiObjects to select the best model for this content size
|
||||
# We'll simulate the model selection by checking available models
|
||||
from modules.interfaces.interfaceAiObjects import aiModels
|
||||
|
||||
# Find the best model for this content size and operation
|
||||
best_model = None
|
||||
best_context_length = 0
|
||||
|
||||
for model_name, model_info in aiModels.items():
|
||||
context_length = model_info.get("contextLength", 0)
|
||||
|
||||
# Skip models with no context length or too small for content
|
||||
if context_length == 0:
|
||||
continue
|
||||
|
||||
# Check if model supports the operation type
|
||||
capabilities = model_info.get("capabilities", [])
|
||||
if options.operationType == OperationTypeEnum.IMAGE_ANALYSE and "imageAnalyse" not in capabilities:
|
||||
continue
|
||||
elif options.operationType == OperationTypeEnum.IMAGE_GENERATE and "imageGenerate" not in capabilities:
|
||||
continue
|
||||
elif options.operationType == OperationTypeEnum.WEB_RESEARCH and "web_search" not in capabilities:
|
||||
continue
|
||||
elif "text_generation" not in capabilities:
|
||||
continue
|
||||
|
||||
# Prefer models that can handle the content without chunking, but allow chunking if needed
|
||||
if context_length >= total_size * 0.8: # 80% of content size
|
||||
if context_length > best_context_length:
|
||||
best_model = model_info
|
||||
best_context_length = context_length
|
||||
elif best_model is None: # Fallback to largest available model
|
||||
if context_length > best_context_length:
|
||||
best_model = model_info
|
||||
best_context_length = context_length
|
||||
|
||||
# Fallback to a reasonable default if no model found
|
||||
if best_model is None:
|
||||
best_model = {
|
||||
"contextLength": 128000, # GPT-4o default
|
||||
"llmName": "gpt-4o"
|
||||
}
|
||||
|
||||
# Calculate appropriate sizes
|
||||
# Convert tokens to bytes (rough estimate: 1 token ≈ 4 characters)
|
||||
context_length_bytes = int(best_model["contextLength"] * 4)
|
||||
max_context_bytes = int(context_length_bytes * 0.9) # 90% of context length
|
||||
text_chunk_size = int(max_context_bytes * 0.7) # 70% of max context for text chunks
|
||||
image_chunk_size = int(max_context_bytes * 0.8) # 80% of max context for image chunks
|
||||
|
||||
logger.debug(f"Selected model: {best_model.get('llmName', 'unknown')} with context length: {best_model['contextLength']}")
|
||||
logger.debug(f"Content size: {total_size} bytes, Max context: {max_context_bytes} bytes")
|
||||
logger.debug(f"Text chunk size: {text_chunk_size} bytes, Image chunk size: {image_chunk_size} bytes")
|
||||
|
||||
return {
|
||||
"maxContextBytes": max_context_bytes,
|
||||
"textChunkSize": text_chunk_size,
|
||||
"imageChunkSize": image_chunk_size
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -503,8 +503,6 @@ CONTINUATION INSTRUCTIONS:
|
|||
Handle text calls with document processing through ExtractionService.
|
||||
UNIFIED PROCESSING: Always use per-chunk processing for consistency.
|
||||
"""
|
||||
# UNIFIED PROCESSING: Always use per-chunk processing for consistency
|
||||
# This ensures MIME-type checking, chunk mapping, and parallel processing
|
||||
return await self.processDocumentsPerChunk(documents, prompt, options)
|
||||
|
||||
async def _processChunksWithMapping(
|
||||
|
|
@ -1194,6 +1192,7 @@ CONTINUATION INSTRUCTIONS:
|
|||
def _getModelCapabilitiesForContent(self, prompt: str, documents: Optional[List[ChatDocument]], options: AiCallOptions) -> Dict[str, int]:
|
||||
"""
|
||||
Get model capabilities for content processing, including appropriate size limits for chunking.
|
||||
Uses centralized model selection to determine chunking parameters.
|
||||
"""
|
||||
# Estimate total content size
|
||||
prompt_size = len(prompt.encode('utf-8'))
|
||||
|
|
@ -1205,57 +1204,38 @@ CONTINUATION INSTRUCTIONS:
|
|||
|
||||
total_size = prompt_size + document_size
|
||||
|
||||
# Use AiObjects to select the best model for this content size
|
||||
# We'll simulate the model selection by checking available models
|
||||
from modules.interfaces.interfaceAiObjects import aiModels
|
||||
|
||||
# Find the best model for this content size and operation
|
||||
best_model = None
|
||||
best_context_length = 0
|
||||
|
||||
for model_name, model_info in aiModels.items():
|
||||
context_length = model_info.get("contextLength", 0)
|
||||
# Use centralized model selection to get the best model for chunking parameters
|
||||
try:
|
||||
from modules.aicore.aicoreModelRegistry import modelRegistry
|
||||
from modules.aicore.aicoreModelSelector import model_selector
|
||||
|
||||
# Skip models with no context length or too small for content
|
||||
if context_length == 0:
|
||||
continue
|
||||
# Get available models and select the best one for this operation
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
selectedModel = model_selector.selectModel(prompt, "", options, availableModels)
|
||||
|
||||
if selectedModel:
|
||||
context_length = selectedModel.contextLength
|
||||
model_name = selectedModel.name
|
||||
logger.debug(f"Selected model for chunking: {model_name} with context length: {context_length}")
|
||||
else:
|
||||
# Fallback to conservative default if no model selected
|
||||
context_length = 128000 # GPT-4o default
|
||||
model_name = "fallback"
|
||||
logger.warning(f"No model selected for chunking, using fallback context length: {context_length}")
|
||||
|
||||
# Check if model supports the operation type
|
||||
capabilities = model_info.get("capabilities", [])
|
||||
if options.operationType == OperationTypeEnum.IMAGE_ANALYSE and "imageAnalyse" not in capabilities:
|
||||
continue
|
||||
elif options.operationType == OperationTypeEnum.IMAGE_GENERATE and "imageGenerate" not in capabilities:
|
||||
continue
|
||||
elif options.operationType == OperationTypeEnum.WEB_RESEARCH and "web_search" not in capabilities:
|
||||
continue
|
||||
elif "text_generation" not in capabilities:
|
||||
continue
|
||||
|
||||
# Prefer models that can handle the content without chunking, but allow chunking if needed
|
||||
if context_length >= total_size * 0.8: # 80% of content size
|
||||
if context_length > best_context_length:
|
||||
best_model = model_info
|
||||
best_context_length = context_length
|
||||
elif best_model is None: # Fallback to largest available model
|
||||
if context_length > best_context_length:
|
||||
best_model = model_info
|
||||
best_context_length = context_length
|
||||
|
||||
# Fallback to a reasonable default if no model found
|
||||
if best_model is None:
|
||||
best_model = {
|
||||
"contextLength": 128000, # GPT-4o default
|
||||
"llmName": "gpt-4o"
|
||||
}
|
||||
except Exception as e:
|
||||
# Fallback to conservative default if model selection fails
|
||||
context_length = 128000 # GPT-4o default
|
||||
model_name = "fallback"
|
||||
logger.error(f"Model selection failed for chunking: {e}, using fallback context length: {context_length}")
|
||||
|
||||
# Calculate appropriate sizes
|
||||
# Convert tokens to bytes (rough estimate: 1 token ≈ 4 characters)
|
||||
context_length_bytes = int(best_model["contextLength"] * 4)
|
||||
context_length_bytes = int(context_length * 4)
|
||||
max_context_bytes = int(context_length_bytes * 0.9) # 90% of context length
|
||||
text_chunk_size = int(max_context_bytes * 0.7) # 70% of max context for text chunks
|
||||
image_chunk_size = int(max_context_bytes * 0.8) # 80% of max context for image chunks
|
||||
|
||||
logger.debug(f"Selected model: {best_model.get('llmName', 'unknown')} with context length: {best_model['contextLength']}")
|
||||
logger.debug(f"Content size: {total_size} bytes, Max context: {max_context_bytes} bytes")
|
||||
logger.debug(f"Text chunk size: {text_chunk_size} bytes, Image chunk size: {image_chunk_size} bytes")
|
||||
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ class ImageChunker(Chunker):
|
|||
|
||||
return chunks
|
||||
|
||||
def _tileImage(self, image: "Image.Image", maxBytes: int, tileSize: int, quality: int, originalPixels: int) -> List[Dict[str, Any]]:
|
||||
def _tileImage(self, image: Any, maxBytes: int, tileSize: int, quality: int, originalPixels: int) -> List[Dict[str, Any]]:
|
||||
"""Split image into tiles if it's still too large after compression."""
|
||||
chunks = []
|
||||
width, height = image.size
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@ import logging
|
|||
import time
|
||||
|
||||
from .subRegistry import ExtractorRegistry, ChunkerRegistry
|
||||
from .subPipeline import runExtraction, poolAndLimit, applyAiIfRequested
|
||||
from .subPipeline import runExtraction
|
||||
from modules.datamodels.datamodelExtraction import ContentExtracted, ContentPart, MergeStrategy
|
||||
from modules.datamodels.datamodelChat import ChatDocument
|
||||
from modules.datamodels.datamodelAi import AiCallResponse
|
||||
from modules.interfaces.interfaceAiObjects import aiModels
|
||||
from modules.aicore.aicoreModelRegistry import modelRegistry
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -91,8 +91,6 @@ class ExtractionService:
|
|||
else:
|
||||
logger.debug(f"No chunking needed - {len(ec.parts)} parts fit within size limits")
|
||||
|
||||
ec = applyAiIfRequested(ec, options)
|
||||
|
||||
# Calculate timing and emit stats
|
||||
endTime = time.time()
|
||||
processingTime = endTime - startTime
|
||||
|
|
@ -103,7 +101,8 @@ class ExtractionService:
|
|||
|
||||
# Use internal extraction model for pricing
|
||||
modelName = "internal_extraction"
|
||||
priceUsd = aiModels[modelName]["calculatePriceUsd"](processingTime, bytesSent, bytesReceived)
|
||||
model = modelRegistry.getModel(modelName)
|
||||
priceUsd = model.calculatePriceUsd(processingTime, bytesSent, bytesReceived)
|
||||
|
||||
# Create AiCallResponse with real calculation
|
||||
aiResponse = AiCallResponse(
|
||||
|
|
|
|||
|
|
@ -298,37 +298,3 @@ def _applySizeLimit(parts: List[ContentPart], maxSize: int) -> List[ContentPart]
|
|||
|
||||
return kept
|
||||
|
||||
|
||||
def applyAiIfRequested(extracted: ContentExtracted, options: Dict[str, Any]) -> ContentExtracted:
|
||||
"""
|
||||
Apply AI processing if requested in options.
|
||||
This is a placeholder for actual AI integration.
|
||||
"""
|
||||
prompt = options.get("prompt")
|
||||
operationType = options.get("operationType", "general")
|
||||
|
||||
if not prompt:
|
||||
return extracted
|
||||
|
||||
# Placeholder AI processing based on operationType
|
||||
if operationType == "analyse":
|
||||
# Add analysis metadata to parts
|
||||
for part in extracted.parts:
|
||||
if part.typeGroup in ("text", "table", "structure"):
|
||||
part.metadata["ai_processed"] = True
|
||||
part.metadata["operation_type"] = operationType
|
||||
elif operationType == "plan":
|
||||
# Add plan generation metadata
|
||||
for part in extracted.parts:
|
||||
if part.typeGroup == "text":
|
||||
part.metadata["ai_processed"] = True
|
||||
part.metadata["operation_type"] = operationType
|
||||
elif operationType == "generate":
|
||||
# Add content generation metadata
|
||||
for part in extracted.parts:
|
||||
part.metadata["ai_processed"] = True
|
||||
part.metadata["operation_type"] = operationType
|
||||
|
||||
return extracted
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import time
|
|||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
from modules.datamodels.datamodelChat import ChatDocument
|
||||
from modules.datamodels.datamodelAi import AiCallResponse
|
||||
from modules.interfaces.interfaceAiObjects import aiModels
|
||||
from modules.aicore.aicoreModelRegistry import modelRegistry
|
||||
from modules.services.serviceGeneration.subDocumentUtility import (
|
||||
getFileExtension,
|
||||
getMimeTypeFromExtension,
|
||||
|
|
@ -430,7 +430,8 @@ class GenerationService:
|
|||
|
||||
# Use internal generation model for pricing
|
||||
modelName = "internal_generation"
|
||||
priceUsd = aiModels[modelName]["calculatePriceUsd"](processingTime, 0, bytesReceived)
|
||||
model = modelRegistry.getModel(modelName)
|
||||
priceUsd = model.calculatePriceUsd(processingTime, 0, bytesReceived)
|
||||
|
||||
aiResponse = AiCallResponse(
|
||||
content="", # No content for generation stats needed
|
||||
|
|
@ -457,7 +458,8 @@ class GenerationService:
|
|||
|
||||
# Use internal generation model for pricing
|
||||
modelName = "internal_generation"
|
||||
priceUsd = aiModels[modelName]["calculatePriceUsd"](processingTime, 0, 0)
|
||||
model = modelRegistry.getModel(modelName)
|
||||
priceUsd = model.calculatePriceUsd(processingTime, 0, 0)
|
||||
|
||||
aiResponse = AiCallResponse(
|
||||
content="", # No content for generation stats needed
|
||||
|
|
|
|||
163
test_ai_model_selection.py
Normal file
163
test_ai_model_selection.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
AI Model Selection Test - Prints prioritized fallback model lists used for AI calls
|
||||
|
||||
Scenarios mirror typical calls in workflows/ (task planning, action planning,
|
||||
analysis, and react-mode decisions), showing which models are shortlisted and
|
||||
their final prioritized order after rating and cost tie-breaking.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
# Ensure gateway is on path when running directly
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
from modules.features.chatPlayground.mainChatPlayground import getServices
|
||||
from modules.datamodels.datamodelAi import (
|
||||
AiCallOptions,
|
||||
OperationTypeEnum,
|
||||
PriorityEnum,
|
||||
ProcessingModeEnum,
|
||||
)
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.aicore.aicoreModelRegistry import modelRegistry
|
||||
from modules.aicore.aicoreModelSelector import model_selector
|
||||
|
||||
|
||||
class ModelSelectionTester:
|
||||
def __init__(self) -> None:
|
||||
testUser = User(
|
||||
id="test_user_models",
|
||||
username="test_models",
|
||||
email="test@example.com",
|
||||
fullName="Test Models",
|
||||
language="en",
|
||||
mandateId="test_mandate",
|
||||
)
|
||||
self.services = getServices(testUser, None)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
from modules.services.serviceAi.mainServiceAi import AiService
|
||||
|
||||
self.services.ai = await AiService.create(self.services)
|
||||
|
||||
async def _printFallbackList(self, title: str, prompt: str, options: AiCallOptions) -> None:
|
||||
print(f"\n{'='*80}")
|
||||
print(f"{title}")
|
||||
print(f"{'='*80}")
|
||||
print(
|
||||
f"Operation={options.operationType.name}, Priority={options.priority.name}, ProcessingMode={options.processingMode.name}"
|
||||
)
|
||||
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
fallbackModels = model_selector.getFallbackModels(
|
||||
prompt=prompt,
|
||||
context="",
|
||||
options=options,
|
||||
availableModels=availableModels,
|
||||
)
|
||||
|
||||
if not fallbackModels:
|
||||
print("No suitable models found (capability filter returned empty list).")
|
||||
return
|
||||
|
||||
print("Prioritized fallback model sequence (name | quality | speed | $/1k in | ctx):")
|
||||
for idx, m in enumerate(fallbackModels, 1):
|
||||
costIn = getattr(m, "costPer1kTokensInput", 0.0)
|
||||
print(
|
||||
f" {idx:>2}. {m.name} | Q={getattr(m, 'qualityRating', 0)} | S={getattr(m, 'speedRating', 0)} | ${costIn:.4f} | ctx={getattr(m, 'contextLength', 0)}"
|
||||
)
|
||||
|
||||
async def run(self) -> None:
|
||||
# Scenarios reflecting workflows/
|
||||
scenarios: List[Tuple[str, str, AiCallOptions]] = []
|
||||
|
||||
# Task planning (taskPlanner, modeActionplan)
|
||||
scenarios.append(
|
||||
(
|
||||
"PLAN - Quality, Detailed",
|
||||
"Task planning for a multi-step business workflow.",
|
||||
AiCallOptions(
|
||||
operationType=OperationTypeEnum.PLAN,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
compressPrompt=False,
|
||||
compressContext=False,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
maxCost=0.10,
|
||||
maxProcessingTime=30,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Result validation / analysis (modeActionplan)
|
||||
scenarios.append(
|
||||
(
|
||||
"ANALYSE - Balanced, Advanced",
|
||||
"Validate action plan correctness and completeness.",
|
||||
AiCallOptions(
|
||||
operationType=OperationTypeEnum.ANALYSE,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
compressPrompt=True,
|
||||
compressContext=False,
|
||||
processingMode=ProcessingModeEnum.ADVANCED,
|
||||
maxCost=0.05,
|
||||
maxProcessingTime=30,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# React mode - action selection (modeReact)
|
||||
scenarios.append(
|
||||
(
|
||||
"GENERAL - Balanced, Advanced (React: action selection)",
|
||||
"Select next best action from context and state.",
|
||||
AiCallOptions(
|
||||
operationType=OperationTypeEnum.GENERAL,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
compressPrompt=True,
|
||||
compressContext=True,
|
||||
processingMode=ProcessingModeEnum.ADVANCED,
|
||||
maxCost=0.03,
|
||||
maxProcessingTime=20,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# React mode - parameter suggestion (modeReact example)
|
||||
scenarios.append(
|
||||
(
|
||||
"ANALYSE - Balanced, Advanced (React: parameter suggestion)",
|
||||
"Suggest parameters for the selected action as JSON.",
|
||||
AiCallOptions(
|
||||
operationType=OperationTypeEnum.ANALYSE,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
compressPrompt=True,
|
||||
compressContext=False,
|
||||
processingMode=ProcessingModeEnum.ADVANCED,
|
||||
maxCost=0.05,
|
||||
maxProcessingTime=30,
|
||||
resultFormat="json",
|
||||
temperature=0.3,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Iterate and print lists
|
||||
for title, prompt, options in scenarios:
|
||||
await self._printFallbackList(title, prompt, options)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
tester = ModelSelectionTester()
|
||||
await tester.initialize()
|
||||
await tester.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Loading…
Reference in a new issue