103 lines
3.5 KiB
Python
103 lines
3.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script to verify AI fallback mechanism from Basic to Advanced when context length is exceeded.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from modules.interfaces.interfaceAiCalls import AiCalls
|
|
from modules.connectors.connectorAiOpenai import ContextLengthExceededException
|
|
|
|
# Logging setup: request INFO-level output via the root logger configuration.
logging.basicConfig(level=logging.INFO)

# Logger named after this module; every log call in this script goes through it.
logger = logging.getLogger(__name__)
|
|
|
|
async def test_context_length_fallback():
    """Call the Basic text endpoint with an oversized context.

    A context made of 10000 repetitions of a short sentence is sent,
    together with a one-line summarization prompt, to
    ``AiCalls.callAiTextBasic``. The script's stated expectation is that
    this exceeds the Basic model's context limit and that the call falls
    back to the Advanced model; this function itself only observes whether
    the call returns without raising.

    Returns:
        bool: True if the call returned and the first 100 characters of
        its result were logged; False if any exception was raised while
        calling or while logging the result.
    """
    client = AiCalls()

    # Filler text repeated often enough that it is expected to overflow
    # the Basic model's context window.
    oversized_context = "This is a test context. " * 10000
    question = "Please summarize this context in one sentence."

    logger.info("Testing AI Basic with large context (should trigger fallback)...")

    succeeded = False
    try:
        # Expected path: context-length error inside the Basic call,
        # followed by a fallback to Advanced.
        result = await client.callAiTextBasic(question, oversized_context)
        logger.info(f"✅ Fallback successful! Result: {result[:100]}...")
        succeeded = True
    except Exception as e:
        logger.error(f"❌ Test failed: {str(e)}")

    return succeeded
|
|
|
|
async def test_direct_context_length_exception():
    """Check that the OpenAI connector raises ContextLengthExceededException.

    A single user message made of 50000 repetitions of a short sentence is
    passed straight to ``AiOpenai.callAiBasic``. The test passes only when
    that call raises ``ContextLengthExceededException``.

    Returns:
        bool: True if ``ContextLengthExceededException`` was raised; False
        if the call returned normally or raised any other exception.
    """
    # Imported here rather than at module level, as in the original script.
    from modules.connectors.connectorAiOpenai import AiOpenai

    logger.info("Testing direct ContextLengthExceededException...")

    raised_as_expected = False
    try:
        connector = AiOpenai()

        # One very large user message, intended to exceed the context length.
        oversized_text = "Test message. " * 50000
        conversation = [{"role": "user", "content": oversized_text}]

        await connector.callAiBasic(conversation)

        # Reaching this line means the connector accepted the payload.
        logger.error("❌ Expected ContextLengthExceededException but none was raised")
    except ContextLengthExceededException as e:
        logger.info(f"✅ ContextLengthExceededException properly raised: {str(e)}")
        raised_as_expected = True
    except Exception as e:
        logger.error(f"❌ Unexpected exception: {str(e)}")

    return raised_as_expected
|
|
|
|
async def main():
    """Run all fallback tests, log a summary, and report the overall outcome.

    The test coroutines are awaited one after another. A test that raises
    is logged as crashed and recorded as failed, so a single crash does not
    prevent the remaining tests from running.

    Returns:
        bool: True when every test produced a truthy result, False
        otherwise. The ``__main__`` guard below turns this into the
        process exit status.
    """
    logger.info("Starting AI fallback mechanism tests...")

    tests = [
        ("Context Length Fallback", test_context_length_fallback),
        ("Direct Exception Test", test_direct_context_length_exception),
    ]

    results = []
    for test_name, test_func in tests:
        logger.info(f"\n--- Running {test_name} ---")
        try:
            result = await test_func()
        except Exception as e:
            # A crashing test counts as a failure; keep going with the rest.
            logger.error(f"Test {test_name} crashed: {str(e)}")
            result = False
        results.append((test_name, result))

    # Summary
    logger.info("\n" + "="*50)
    logger.info("TEST SUMMARY")
    logger.info("="*50)

    for test_name, result in results:
        status = "✅ PASSED" if result else "❌ FAILED"
        logger.info(f"{test_name}: {status}")

    passed = sum(1 for _, result in results if result)
    logger.info(f"\nTotal: {passed}/{len(results)} tests passed")

    all_passed = passed == len(results)
    if all_passed:
        logger.info("🎉 All tests passed! Fallback mechanism is working correctly.")
    else:
        logger.warning("⚠️ Some tests failed. Please check the implementation.")

    return all_passed


if __name__ == "__main__":
    # Exit non-zero when any test failed so shells and CI can detect it;
    # previously the script always exited with status 0.
    raise SystemExit(0 if asyncio.run(main()) else 1)
|