230 lines
No EOL
7.8 KiB
Python
230 lines
No EOL
7.8 KiB
Python
"""
|
|
Test script for the ChatManager workflow with simulated file uploads.

Demonstrates the complete workflow from file upload to chat execution.
|
|
"""
|
|
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
import os
|
|
import sys
|
|
from typing import Dict, Any, List, Tuple
|
|
from datetime import datetime
|
|
|
|
# Logging setup: DEBUG level, timestamped records, explicit stream handler
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
logging.basicConfig(
    level=logging.DEBUG,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("test_workflow")

# Make the parent of this file's directory importable so that the
# `modules` package imports below can be resolved
_PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_PROJECT_DIR)
|
|
|
|
# Import modules
|
|
from modules.lucydom_interface import get_lucydom_interface
|
|
from modules.chat import get_chat_manager
|
|
|
|
def _png_chunk(tag: bytes, payload: bytes) -> bytes:
    """Serialize one PNG chunk: big-endian length, 4-byte type, data, CRC-32 over type + data."""
    import struct
    import zlib

    checksum = zlib.crc32(tag + payload) & 0xFFFFFFFF
    return struct.pack(">I", len(payload)) + tag + payload + struct.pack(">I", checksum)


def _solid_png_stdlib(width: int, height: int, rgb: Tuple[int, int, int]) -> bytes:
    """Encode a single-colour 8-bit RGB PNG using only the standard library."""
    import struct
    import zlib

    # IHDR payload: width, height, bit depth 8, colour type 2 (truecolour),
    # compression method 0, filter method 0, no interlacing.
    header = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
    # Each scanline is prefixed with filter type 0 ("None"), followed by raw RGB triples.
    scanline = b"\x00" + bytes(rgb) * width
    pixel_data = zlib.compress(scanline * height)
    return (
        b"\x89PNG\r\n\x1a\n"  # PNG signature
        + _png_chunk(b"IHDR", header)
        + _png_chunk(b"IDAT", pixel_data)
        + _png_chunk(b"IEND", b"")
    )


def _build_test_image_bytes() -> bytes:
    """Return PNG bytes of a 100x100 red image, via PIL when installed, else via the stdlib encoder."""
    try:
        from PIL import Image
    except ImportError:
        # The previous fallback uploaded only the 8-byte PNG signature, which
        # is not a decodable image; build a complete PNG instead.
        return _solid_png_stdlib(100, 100, (255, 0, 0))

    import io

    buffer = io.BytesIO()
    Image.new('RGB', (100, 100), color='red').save(buffer, format='PNG')
    return buffer.getvalue()


async def create_test_files(mandate_id: int, user_id: int) -> Tuple[int, int]:
    """
    Creates a text file and an image for testing and uploads them.

    Both files are stored through ``save_uploaded_file`` of the interface
    returned by ``get_lucydom_interface``. The image is a 100x100 red PNG;
    it is produced with PIL when available and with a small standard-library
    encoder otherwise, so the upload is a valid PNG in both cases. No file is
    written to the local working directory.

    Args:
        mandate_id: ID of the mandate
        user_id: ID of the user

    Returns:
        Tuple with (text_file_id, image_file_id)
    """
    logger.info("Creating test files...")

    lucy_interface = get_lucydom_interface(mandate_id, user_id)

    # Create text file (the literal is flush-left so the uploaded text
    # carries no leading indentation)
    text_content = """
This is a test text file for the ChatManager workflow.
It contains some information for testing document processing.

The ChatManager should be able to process this file
and extract relevant information from it.

This file serves as an example for text-based documents that can be
used in a chat workflow.
"""
    text_file = lucy_interface.save_uploaded_file(text_content.encode('utf-8'), "test_document.txt")
    text_file_id = text_file["id"]
    logger.info(f"Text file created with ID: {text_file_id}")

    # Create the test image and upload it (single upload path for both the
    # PIL and the fallback encoder)
    image_file = lucy_interface.save_uploaded_file(_build_test_image_bytes(), "test_image.png")
    image_file_id = image_file["id"]
    logger.info(f"Image file created with ID: {image_file_id}")

    return text_file_id, image_file_id
|
|
|
|
|
|
|
|
async def run_chat_workflow(mandate_id: int, user_id: int, file_ids: List[int]) -> Dict[str, Any]:
    """
    Runs one chat workflow that references the given files.

    A chat manager is obtained for the mandate/user pair, a request made of a
    fixed prompt plus the file IDs is passed to its ``chat_run`` coroutine,
    and the awaited result is returned unchanged.

    Args:
        mandate_id: ID of the mandate
        user_id: ID of the user
        file_ids: List of file IDs attached to the request

    Returns:
        The workflow result as returned by ``chat_run``
    """
    logger.info(f"Starting chat workflow with files: {file_ids}")

    manager = get_chat_manager(mandate_id, user_id)

    # Request payload: the prompt and the IDs of the files to work on
    request = dict(
        prompt="Bitte zähle mir zusammen wieviele Pixel das Bild hat und wieviele Zeichen der Text der Dokumente hat",
        list_file_id=file_ids,
    )

    result = await manager.chat_run(request)
    logger.info(f"Workflow completed with ID: {result['id']}")
    return result
|
|
|
|
def _preview(text: Any, limit: int = 100) -> str:
    """Return *text* as a string, cut to *limit* characters plus '...' when it is longer."""
    if text is None:
        return ''
    if not isinstance(text, str):
        text = str(text)
    if len(text) > limit:
        return text[:limit] + '...'
    return text


def analyze_workflow_result(workflow: Dict[str, Any]) -> None:
    """
    Analyzes and outputs information about the workflow result.

    Logs the workflow ID and status, then for every message its role, a
    preview of its content (first 100 characters) and its documents with
    their contents, and finally previews of the first 10 log entries.
    Missing or ``None`` collections are treated as empty, and ``None`` or
    non-string texts are previewed without raising.

    Args:
        workflow: The workflow result; must contain the keys 'id' and 'status'

    Raises:
        KeyError: If 'id' or 'status' is missing from the workflow
    """
    logger.info("Analyzing workflow result:")
    logger.info(f"Workflow ID: {workflow['id']}")
    logger.info(f"Status: {workflow['status']}")

    # `or []` also covers keys that are present but set to None
    messages = workflow.get('messages') or []
    logger.info(f"Number of messages: {len(messages)}")

    for msg_no, message in enumerate(messages, start=1):
        logger.info(f"Message {msg_no}:")
        logger.info(f" Role: {message.get('role', 'unknown')}")

        # Show only the first 100 characters of content
        logger.info(f" Content: {_preview(message.get('content', ''))}")

        # Show documents in the message
        documents = message.get('documents') or []
        logger.info(f" Documents: {len(documents)}")
        for doc_no, doc in enumerate(documents, start=1):
            doc_id = doc.get('id', 'no ID')
            file_id = doc.get('file_id', 'no file_id')
            logger.info(f" Document {doc_no}: ID={doc_id}, File-ID={file_id}")

            # Information about contents
            for part_no, part in enumerate(doc.get('contents') or [], start=1):
                part_name = part.get('name', 'no name')
                part_type = part.get('content_type', 'unknown')
                logger.info(f" Content {part_no}: {part_name} ({part_type})")

    logs = workflow.get('logs') or []
    logger.info(f"Logs: {len(logs)}")
    # Only the first 10 log entries are printed
    for log_no, log in enumerate(logs[:10], start=1):
        log_type = log.get('type', 'info')
        logger.info(f" Log {log_no} [{log_type}]: {_preview(log.get('message', ''))}")
|
|
|
|
async def cleanup_test_files(mandate_id: int, user_id: int, file_ids: List[int]) -> None:
    """
    Deletes the test files created for a run.

    Every ID is handled independently: a falsy result from ``delete_file``
    is logged as a warning, an exception is logged as an error, and in both
    cases the remaining IDs are still processed.

    Args:
        mandate_id: ID of the mandate
        user_id: ID of the user
        file_ids: List of file IDs to delete
    """
    logger.info("Starting cleanup of test files...")

    storage = get_lucydom_interface(mandate_id, user_id)

    for current_id in file_ids:
        try:
            if not storage.delete_file(current_id):
                logger.warning(f"Error deleting file with ID {current_id}")
                continue
            logger.info(f"File with ID {current_id} successfully deleted")
        except Exception as exc:
            # Best effort: one failing deletion must not stop the others
            logger.error(f"Error deleting file with ID {current_id}: {str(exc)}")

    logger.info("Cleanup completed")
|
|
|
|
async def main():
    """
    Main function that controls the entire test process.

    Steps: create the test files, run the chat workflow on them, log an
    analysis of the result and, when CLEANUP is set, delete the test files.
    Cleanup runs whether or not the earlier steps succeeded, so uploaded
    files are not left behind after a failing run. Exceptions are logged
    with their traceback and are not propagated to the caller.
    """
    # Test parameters
    MANDATE_ID = 1  # Test mandate ID
    USER_ID = 1  # Test user ID
    CLEANUP = True  # Cleanup after test

    file_ids: List[int] = []  # stays empty until both uploads have succeeded
    failed = False

    try:
        logger.info("=== ChatManager test workflow started ===")

        # Step 1: Create test files
        text_file_id, image_file_id = await create_test_files(MANDATE_ID, USER_ID)
        file_ids = [text_file_id, image_file_id]

        # Step 2: Execute chat workflow
        workflow_result = await run_chat_workflow(MANDATE_ID, USER_ID, file_ids)

        # Step 3: Analyze result
        analyze_workflow_result(workflow_result)

    except Exception as e:
        failed = True
        logger.error(f"Error in test workflow: {str(e)}", exc_info=True)

    finally:
        # Step 4: Optional cleanup; placed in `finally` so that files created
        # in step 1 are removed even when step 2 or 3 raised
        if CLEANUP and file_ids:
            try:
                await cleanup_test_files(MANDATE_ID, USER_ID, file_ids)
            except Exception as e:
                failed = True
                logger.error(f"Error in test workflow: {str(e)}", exc_info=True)

    if failed:
        logger.info("=== Test workflow ended with error ===")
    else:
        logger.info("=== Test workflow successfully completed ===")
|
|
|
|
if __name__ == "__main__":
    # asyncio.run() creates a new event loop, runs main() to completion and
    # closes the loop afterwards; it replaces the deprecated pattern of
    # calling asyncio.get_event_loop() when no loop is running
    asyncio.run(main())