289 lines
No EOL
11 KiB
Python
289 lines
No EOL
11 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for retry enhancement in managerChat.py
|
|
Tests that previous action results and review feedback are properly passed to retry prompts.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import sys
|
|
import os
|
|
|
|
# Add the gateway directory to the Python path
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'gateway'))
|
|
|
|
from modules.workflow.managerChat import ChatManager
|
|
from modules.interfaces.interfaceAppModel import User
|
|
from modules.interfaces.interfaceChatModel import ChatWorkflow, ChatMessage
|
|
from modules.interfaces.interfaceChatObjects import ChatObjects
|
|
|
|
# Configure logging: route DEBUG and above to the root handler so debug output
# from the modules under test is visible while the script runs.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
_LOG_LEVEL = logging.DEBUG
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
|
|
|
|
class MockChatObjects(ChatObjects):
    """In-memory ChatObjects replacement used by the retry test.

    Each factory method wraps the incoming dict in a small ad-hoc record
    object, reading the keys it knows about and falling back to a fixed
    placeholder for any key that is missing. Nothing is persisted.
    """

    def createTaskAction(self, action_data):
        """Return a task-action record built from *action_data*.

        Missing keys fall back to "unknown" (execMethod/execAction), a fresh
        ``{}`` (execParameters), "" (execResultLabel) and "PENDING" (status).
        The record starts with empty ``result``/``error`` strings and offers
        setSuccess / setError / isSuccessful.
        """

        class MockTaskAction:
            def __init__(self, payload):
                pick = payload.get
                self.id = "test_action_id"
                self.execMethod = pick("execMethod", "unknown")
                self.execAction = pick("execAction", "unknown")
                self.execParameters = pick("execParameters", {})
                self.execResultLabel = pick("execResultLabel", "")
                self.status = pick("status", "PENDING")
                self.result = ""
                self.error = ""

            def setSuccess(self):
                # Only the status changes; any earlier error text is kept.
                self.status = "COMPLETED"

            def setError(self, error):
                self.status = "FAILED"
                self.error = error

            def isSuccessful(self):
                return self.status == "COMPLETED"

        return MockTaskAction(action_data)

    def createChatDocument(self, document_data):
        """Return a chat-document record built from *document_data*.

        Missing keys fall back to "" (fileId), "unknown" (filename), 0
        (fileSize) and "application/octet-stream" (mimeType); ``content``
        always starts empty.
        """

        class MockChatDocument:
            def __init__(self, payload):
                pick = payload.get
                self.fileId = pick("fileId", "")
                self.filename = pick("filename", "unknown")
                self.fileSize = pick("fileSize", 0)
                self.mimeType = pick("mimeType", "application/octet-stream")
                self.content = ""

        return MockChatDocument(document_data)

    def createWorkflowMessage(self, message_data):
        """Return a workflow-message record built from *message_data*.

        Missing keys fall back to "assistant" (role), "step" (status), 1
        (sequenceNr), a fresh ``[]`` (documents) and "" for every other field.
        """

        class MockWorkflowMessage:
            def __init__(self, payload):
                pick = payload.get
                self.workflowId = pick("workflowId", "")
                self.role = pick("role", "assistant")
                self.message = pick("message", "")
                self.status = pick("status", "step")
                self.sequenceNr = pick("sequenceNr", 1)
                self.publishedAt = pick("publishedAt", "")
                self.actionId = pick("actionId", "")
                self.actionMethod = pick("actionMethod", "")
                self.actionName = pick("actionName", "")
                self.documentsLabel = pick("documentsLabel", "")
                self.documents = pick("documents", [])

        return MockWorkflowMessage(message_data)
|
|
|
|
class MockServiceContainer:
    """Canned-response stand-in for the service container ChatManager uses.

    Stores the user and workflow it was built with and answers every service
    call with fixed, in-memory data: no file is written, no AI model is
    called, and no method is actually executed.
    """

    # Raw JSON returned by callAiTextBasic: a single "document.extract" action.
    _BASIC_AI_REPLY = '{"actions": [{"method": "document", "action": "extract", "parameters": {"documentList": ["test"], "aiPrompt": "Test prompt"}, "resultLabel": "task1_action1_test", "description": "Test action"}]}'

    # Raw JSON returned by callAiTextAdvanced: a one-task plan.
    _ADVANCED_AI_REPLY = '{"overview": "Test plan", "tasks": [{"id": "task_1", "description": "Test task", "dependencies": [], "expected_outputs": ["output1"], "success_criteria": ["criteria1"], "required_documents": [], "estimated_complexity": "low", "ai_prompt": "Test prompt"}]}'

    # Signatures reported by getMethodsList (copied into a new list per call).
    _METHOD_SIGNATURES = (
        "document.extract(documentList, aiPrompt)",
        "document.analyze(documentList, aiPrompt)",
    )

    # Extension -> MIME type table consulted by getMimeTypeFromExtension.
    _MIME_BY_EXTENSION = {
        'txt': 'text/plain',
        'pdf': 'application/pdf',
        'doc': 'application/msword',
        'json': 'application/json',
    }

    # Ordered (filename suffix, MIME type) probes for detectContentTypeFromData.
    _MIME_BY_SUFFIX = (
        ('.txt', 'text/plain'),
        ('.pdf', 'application/pdf'),
        ('.json', 'application/json'),
    )

    def __init__(self, user, workflow):
        self.user = user
        self.workflow = workflow

    def getMethodsList(self):
        """Return a fresh list of the two mock document method signatures."""
        return list(self._METHOD_SIGNATURES)

    async def summarizeChat(self, messages):
        """Ignore *messages* and return a fixed summary string."""
        return "Mock chat history summary"

    def getDocumentReferenceList(self):
        """Return fresh, empty 'chat' and 'history' reference lists."""
        return dict(chat=[], history=[])

    def getConnectionReferenceList(self):
        """Return the two placeholder connection names."""
        return [f"connection{index}" for index in (1, 2)]

    def getFileInfo(self, fileId):
        """Describe a fake 1024-byte text file whose name embeds *fileId*."""
        return dict(filename=f"test_file_{fileId}.txt", size=1024, mimeType="text/plain")

    def createFile(self, fileName, mimeType, content, base64encoded=False):
        """Pretend to store a file; return an id derived from *fileName*."""
        return "file_id_{}".format(fileName)

    def createDocument(self, fileName, mimeType, content, base64encoded=False):
        """Wrap *content* in a lightweight document object; nothing is stored."""

        class MockDocument:
            def __init__(self, filename, mime_type, body):
                self.filename = filename
                self.mimeType = mime_type
                self.content = body
                # Size is simply the length of the content as passed in.
                self.fileSize = len(body)

        return MockDocument(fileName, mimeType, content)

    def getFileExtension(self, filename):
        """Return the text after the last '.' in *filename*, or 'txt' if none."""
        _, dot, suffix = filename.rpartition('.')
        return suffix if dot else 'txt'

    def getMimeTypeFromExtension(self, extension):
        """Map a (case-sensitive) extension to a MIME type, defaulting to binary."""
        return self._MIME_BY_EXTENSION.get(extension, 'application/octet-stream')

    def detectContentTypeFromData(self, file_bytes, filename):
        """Guess a MIME type from *filename*'s suffix; *file_bytes* is ignored."""
        for suffix, mime in self._MIME_BY_SUFFIX:
            if filename.endswith(suffix):
                return mime
        return 'application/octet-stream'

    async def callAiTextBasic(self, prompt):
        """Ignore *prompt* and return a canned action-definition JSON string."""
        return self._BASIC_AI_REPLY

    async def callAiTextAdvanced(self, prompt):
        """Ignore *prompt* and return a canned task-plan JSON string."""
        return self._ADVANCED_AI_REPLY

    async def executeAction(self, methodName, actionName, parameters):
        """Report a successful execution without running anything."""

        class MockResult:
            def __init__(self):
                self.success = True
                self.data = dict(result="Mock execution result", documents=[])
                self.error = None

        return MockResult()
|
|
|
|
def _report_check(passed, success_message, failure_message, failures):
    """Log the outcome of one check and record its message if it failed.

    Args:
        passed: Truthy when the check succeeded.
        success_message: Logged at INFO level on success.
        failure_message: Logged at ERROR level on failure and appended to
            *failures*.
        failures: List collecting failure messages; mutated in place.
    """
    if passed:
        logger.info(success_message)
    else:
        logger.error(failure_message)
        failures.append(failure_message)


def _result_label(action):
    """Return the result label of a fallback action as a string ("" if absent).

    Accepts either an action dict carrying a ``"resultLabel"`` key (the shape
    the mocked AI reply uses) or a task-action object exposing
    ``execResultLabel`` (the shape MockChatObjects.createTaskAction builds).
    NOTE(review): which shape _createFallbackActions returns is not visible
    here; both are tolerated so the check cannot crash on either.
    """
    if isinstance(action, dict):
        label = action.get("resultLabel")
    else:
        label = getattr(action, "execResultLabel", None)
    return "" if label is None else str(label)


async def test_retry_enhancement():
    """Exercise ChatManager's retry-aware action definition with mocks.

    Scenarios:
      1. defineTaskActions without retry context.
      2. defineTaskActions with a retry context carrying previous action
         results and review feedback.
      3. _createActionDefinitionPrompt: the retry sections must appear in
         the generated prompt.
      4. _createFallbackActions: some fallback action's result label must
         mention "retry".

    Every check runs and is logged before any verdict is reached, so one
    failing check does not hide the others.

    Raises:
        AssertionError: if any check in scenarios 3 or 4 fails.
    """
    logger.info("Testing retry enhancement in managerChat.py")
    failures = []

    # Create mock objects
    mock_user = User(id="test_user", username="testuser", email="test@example.com", mandateId="test_mandate")
    mock_chat_objects = MockChatObjects()
    mock_workflow = ChatWorkflow(
        id="test_workflow",
        userId="test_user",
        status="active",
        messages=[],
        createdAt="2024-01-01T00:00:00Z",
        updatedAt="2024-01-01T00:00:00Z",
        mandateId="test_mandate",
        currentRound=1,
        lastActivity="2024-01-01T00:00:00Z",
        startedAt="2024-01-01T00:00:00Z",
    )

    # Create chat manager
    chat_manager = ChatManager(mock_user, mock_chat_objects)

    # Mock the service container directly instead of initializing
    chat_manager.service = MockServiceContainer(mock_user, mock_workflow)
    chat_manager.workflow = mock_workflow

    # Test 1: Basic action definition without retry
    logger.info("Test 1: Basic action definition")
    task_step = {
        "id": "task_1",
        "description": "Test task",
        "expected_outputs": ["output1"],
        "success_criteria": ["criteria1"],
        "ai_prompt": "Test AI prompt",
    }

    actions = await chat_manager.defineTaskActions(task_step, mock_workflow, [])
    logger.info("Generated %d actions without retry context", len(actions))

    # Test 2: Action definition with retry context
    logger.info("Test 2: Action definition with retry context")
    enhanced_context = {
        'task_step': task_step,
        'workflow': mock_workflow,
        'workflow_id': mock_workflow.id,
        'available_documents': ["test_doc.txt"],
        'previous_results': ["task0_action1_results"],
        'improvements': "Previous attempt failed - ensure comprehensive extraction",
        'retry_count': 1,
        'previous_action_results': [
            {
                'actionMethod': 'document',
                'actionName': 'extract',
                'status': 'failed',
                'error': 'Empty result returned',
                'result': 'No content extracted',
                'resultLabel': 'task1_action1_failed',
            }
        ],
        'previous_review_result': {
            'status': 'retry',
            'reason': 'Incomplete extraction',
            'quality_score': 3,
            'missing_outputs': ['detailed_analysis'],
            'unmet_criteria': ['comprehensive_coverage'],
        },
    }

    retry_actions = await chat_manager.defineTaskActions(task_step, mock_workflow, [], enhanced_context)
    logger.info("Generated %d actions with retry context", len(retry_actions))

    # Test 3: Verify retry context is properly handled
    logger.info("Test 3: Verifying retry context handling")

    # Build the prompt directly so its text can be inspected.
    test_prompt = await chat_manager._createActionDefinitionPrompt(enhanced_context)

    # (marker expected in the prompt, success message, failure message)
    prompt_checks = (
        ("RETRY CONTEXT",
         "✓ Retry context properly included in prompt",
         "✗ Retry context not found in prompt"),
        ("Previous action results that failed",
         "✓ Previous action results included in prompt",
         "✗ Previous action results not found in prompt"),
        ("Previous review feedback",
         "✓ Previous review feedback included in prompt",
         "✗ Previous review feedback not found in prompt"),
        # "Previous attempt failed" is the start of the 'improvements' text.
        ("Previous attempt failed",
         "✓ Improvements needed included in prompt",
         "✗ Improvements needed not found in prompt"),
    )
    for marker, success_message, failure_message in prompt_checks:
        _report_check(marker in test_prompt, success_message, failure_message, failures)

    # Test 4: Verify fallback actions with retry context
    logger.info("Test 4: Testing fallback actions with retry context")
    fallback_actions = chat_manager._createFallbackActions(task_step, enhanced_context)
    logger.info("Generated %d fallback actions with retry context", len(fallback_actions))

    # Check if fallback actions include retry information
    _report_check(
        any("retry" in _result_label(action) for action in fallback_actions),
        "✓ Fallback actions include retry information",
        "✗ Fallback actions missing retry information",
        failures,
    )

    # Previously the failures above were only logged and the test always
    # "passed"; surface them so pytest / the script exit status reflect them.
    if failures:
        raise AssertionError(
            f"Retry enhancement test failed {len(failures)} check(s): " + "; ".join(failures)
        )

    logger.info("Retry enhancement test completed successfully!")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the test coroutine and drive it to completion
    # on a fresh event loop; any exception it raises propagates out of
    # asyncio.run and terminates the script with a traceback.
    _retry_test = test_retry_enhancement()
    asyncio.run(_retry_test)