gateway/test_ai_fallback.py
2025-09-02 18:58:30 +02:00

103 lines
3.5 KiB
Python

#!/usr/bin/env python3
"""
Test script to verify AI fallback mechanism from Basic to Advanced when context length is exceeded.
"""
import asyncio
import logging
from modules.interfaces.interfaceAiCalls import AiCalls
from modules.connectors.connectorAiOpenai import ContextLengthExceededException
# Root logging at INFO level with the default handler and format; the
# module-level logger below is what the test coroutines write to.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
async def test_context_length_fallback():
    """Verify that an oversized Basic call still produces a usable answer.

    Sends a one-sentence-summary prompt with a very large context through
    ``AiCalls.callAiTextBasic``. The expectation behind this script is that the
    context is too long for the Basic model and the call falls back to the
    Advanced model; that fallback logic lives in ``AiCalls``, not here.

    NOTE(review): only the final result is observed, so this cannot tell
    whether the fallback path was actually taken or the Basic model simply
    accepted the input.

    Returns:
        True if the call completed and returned a non-empty string, False
        otherwise (failures are logged; exceptions include their traceback).
    """
    ai_calls = AiCalls()
    # 24 chars * 10000 = 240,000 characters, intended to exceed the
    # Basic model's context limit.
    large_context = "This is a test context. " * 10000
    prompt = "Please summarize this context in one sentence."
    logger.info("Testing AI Basic with large context (should trigger fallback)...")
    try:
        # Expected to hit the context-length error internally and fall back
        # to the Advanced model.
        result = await ai_calls.callAiTextBasic(prompt, large_context)
    except Exception as e:
        # Test-harness boundary: record the failure with its full traceback.
        logger.exception(f"❌ Test failed: {str(e)}")
        return False

    # An empty or non-text answer is not a successful fallback; checking it
    # here also guarantees the preview slice below cannot raise.
    if not isinstance(result, str) or not result.strip():
        logger.error(f"❌ Test failed: expected a non-empty text result, got {result!r}")
        return False

    logger.info(f"✅ Fallback successful! Result: {result[:100]}...")
    return True
async def test_direct_context_length_exception():
    """Verify the OpenAI connector raises ContextLengthExceededException.

    Calls ``AiOpenai.callAiBasic`` directly with one oversized user message
    and expects the dedicated exception, not a generic error and not a
    successful response.

    Returns:
        True if ContextLengthExceededException was raised; False if the call
        succeeded or raised any other exception (logged with traceback).
    """
    from modules.connectors.connectorAiOpenai import AiOpenai
    logger.info("Testing direct ContextLengthExceededException...")
    try:
        openai_connector = AiOpenai()
        # 14 chars * 50000 = 700,000 characters in a single message,
        # intended to exceed the model's context limit.
        large_messages = [
            {"role": "user", "content": "Test message. " * 50000}
        ]
        await openai_connector.callAiBasic(large_messages)
    except ContextLengthExceededException as e:
        logger.info(f"✅ ContextLengthExceededException properly raised: {str(e)}")
        return True
    except Exception as e:
        # Wrong failure mode: keep the traceback so the cause is diagnosable.
        logger.exception(f"❌ Unexpected exception: {str(e)}")
        return False
    else:
        # The call succeeded, which means the limit was not enforced.
        logger.error("❌ Expected ContextLengthExceededException but none was raised")
        return False
async def main():
    """Run each test coroutine in order, log a summary, and report success.

    Every test returns True/False; a test that raises is recorded as failed
    (with its traceback logged) so the remaining tests still run.

    Returns:
        True if every test passed, False otherwise. The ``__main__`` guard
        converts this into the process exit status.
    """
    logger.info("Starting AI fallback mechanism tests...")
    tests = [
        ("Context Length Fallback", test_context_length_fallback),
        ("Direct Exception Test", test_direct_context_length_exception),
    ]

    results = []
    for test_name, test_func in tests:
        logger.info(f"\n--- Running {test_name} ---")
        try:
            result = await test_func()
        except Exception as e:
            logger.exception(f"Test {test_name} crashed: {str(e)}")
            result = False
        results.append((test_name, result))

    # Summary
    logger.info("\n" + "=" * 50)
    logger.info("TEST SUMMARY")
    logger.info("=" * 50)
    for test_name, result in results:
        status = "✅ PASSED" if result else "❌ FAILED"
        logger.info(f"{test_name}: {status}")
    passed = sum(1 for _, result in results if result)
    logger.info(f"\nTotal: {passed}/{len(results)} tests passed")

    all_passed = passed == len(results)
    if all_passed:
        logger.info("🎉 All tests passed! Fallback mechanism is working correctly.")
    else:
        logger.warning("⚠️ Some tests failed. Please check the implementation.")
    return all_passed


if __name__ == "__main__":
    import sys

    # Exit non-zero on failure so CI and shell callers can detect it.
    sys.exit(0 if asyncio.run(main()) else 1)