# gateway/test_ai_model_selection.py  (163 lines, 5.5 KiB, Python)
# NOTE(review): the lines above were a file-listing header pasted into the
# source; converted to a comment so the module remains valid Python.
#!/usr/bin/env python3
"""
AI Model Selection Test - Prints prioritized fallback model lists used for AI calls
Scenarios mirror typical calls in workflows/ (task planning, action planning,
analysis, and react-mode decisions), showing which models are shortlisted and
their final prioritized order after rating and cost tie-breaking.
"""
import asyncio
import os
import sys
from typing import List, Tuple
# Ensure gateway is on path when running directly
sys.path.append(os.path.dirname(__file__))
from modules.features.chatPlayground.mainChatPlayground import getServices
from modules.datamodels.datamodelAi import (
AiCallOptions,
OperationTypeEnum,
PriorityEnum,
ProcessingModeEnum,
)
from modules.datamodels.datamodelUam import User
from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import model_selector
class ModelSelectionTester:
    """Drives the model selector against workflow-like scenarios and prints
    the prioritized fallback model list for each one.

    Uses a throwaway test user so no real user context is required.
    """

    def __init__(self) -> None:
        # Synthetic user: only identity fields, no persistence side effects.
        dummyUser = User(
            id="test_user_models",
            username="test_models",
            email="test@example.com",
            fullName="Test Models",
            language="en",
            mandateId="test_mandate",
        )
        self.services = getServices(dummyUser, None)

    async def initialize(self) -> None:
        """Create and attach the AI service (async factory)."""
        # Imported here rather than at module top — presumably to avoid a
        # circular import at load time; confirm before hoisting.
        from modules.services.serviceAi.mainServiceAi import AiService

        self.services.ai = await AiService.create(self.services)

    async def _printFallbackList(self, title: str, prompt: str, options: AiCallOptions) -> None:
        """Print the prioritized fallback sequence for one scenario.

        Shows name, quality/speed ratings, input cost per 1k tokens and
        context length for every shortlisted model, in final priority order.
        """
        print(f"\n{'='*80}")
        print(f"{title}")
        print(f"{'='*80}")
        print(
            f"Operation={options.operationType.name}, Priority={options.priority.name}, ProcessingMode={options.processingMode.name}"
        )
        candidates = modelRegistry.getAvailableModels()
        ranked = model_selector.getFallbackModels(
            prompt=prompt,
            context="",
            options=options,
            availableModels=candidates,
        )
        if not ranked:
            print("No suitable models found (capability filter returned empty list).")
            return
        print("Prioritized fallback model sequence (name | quality | speed | $/1k in | ctx):")
        for rank, model in enumerate(ranked, 1):
            # getattr defaults guard against registry entries missing metadata.
            inputCost = getattr(model, "costPer1kTokensInput", 0.0)
            quality = getattr(model, "qualityRating", 0)
            speed = getattr(model, "speedRating", 0)
            ctxLen = getattr(model, "contextLength", 0)
            print(
                f"  {rank:>2}. {model.name} | Q={quality} | S={speed} | ${inputCost:.4f} | ctx={ctxLen}"
            )

    async def run(self) -> None:
        """Print fallback lists for each scenario mirrored from workflows/."""
        scenarios: List[Tuple[str, str, AiCallOptions]] = [
            # Task planning (taskPlanner, modeActionplan)
            (
                "PLAN - Quality, Detailed",
                "Task planning for a multi-step business workflow.",
                AiCallOptions(
                    operationType=OperationTypeEnum.PLAN,
                    priority=PriorityEnum.QUALITY,
                    compressPrompt=False,
                    compressContext=False,
                    processingMode=ProcessingModeEnum.DETAILED,
                    maxCost=0.10,
                    maxProcessingTime=30,
                ),
            ),
            # Result validation / analysis (modeActionplan)
            (
                "ANALYSE - Balanced, Advanced",
                "Validate action plan correctness and completeness.",
                AiCallOptions(
                    operationType=OperationTypeEnum.ANALYSE,
                    priority=PriorityEnum.BALANCED,
                    compressPrompt=True,
                    compressContext=False,
                    processingMode=ProcessingModeEnum.ADVANCED,
                    maxCost=0.05,
                    maxProcessingTime=30,
                ),
            ),
            # React mode - action selection (modeReact)
            (
                "GENERAL - Balanced, Advanced (React: action selection)",
                "Select next best action from context and state.",
                AiCallOptions(
                    operationType=OperationTypeEnum.GENERAL,
                    priority=PriorityEnum.BALANCED,
                    compressPrompt=True,
                    compressContext=True,
                    processingMode=ProcessingModeEnum.ADVANCED,
                    maxCost=0.03,
                    maxProcessingTime=20,
                ),
            ),
            # React mode - parameter suggestion (modeReact example)
            (
                "ANALYSE - Balanced, Advanced (React: parameter suggestion)",
                "Suggest parameters for the selected action as JSON.",
                AiCallOptions(
                    operationType=OperationTypeEnum.ANALYSE,
                    priority=PriorityEnum.BALANCED,
                    compressPrompt=True,
                    compressContext=False,
                    processingMode=ProcessingModeEnum.ADVANCED,
                    maxCost=0.05,
                    maxProcessingTime=30,
                    resultFormat="json",
                    temperature=0.3,
                ),
            ),
        ]
        for title, prompt, options in scenarios:
            await self._printFallbackList(title, prompt, options)
async def main() -> None:
    """Build the tester, bring up the AI service, and run every scenario."""
    runner = ModelSelectionTester()
    await runner.initialize()
    await runner.run()


if __name__ == "__main__":
    asyncio.run(main())