244 lines
No EOL
8.2 KiB
Python
244 lines
No EOL
8.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Direct Interface Workflow Test Script

This script bypasses the API layer and works directly with the interface
classes to simulate a user uploading two files and then sending a chat
request that references those files.

It follows the workflow state machine defined in the backend documentation.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import asyncio
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
# Make the directory one level above this script importable. Inserting at
# index 0 gives it precedence over later sys.path entries.
current_dir = os.path.split(os.path.abspath(__file__))[0]
parent_dir = os.path.split(current_dir)[0]
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
|
|
|
|
# Import the project interfaces. Try the top-level ``modules`` package first,
# then fall back to the ``gateway.modules`` package layout.
try:
    from modules.workflowManager import getWorkflowManager
    from modules.lucydomInterface import getLucydomInterface
except ImportError as primary_error:
    print("Error: Required modules not found. Attempting alternative imports...")
    try:
        from gateway.modules.workflowManager import getWorkflowManager
        from gateway.modules.lucydomInterface import getLucydomInterface
    except ImportError as fallback_error:
        print("Error: Could not import required modules. Make sure the script is run from the correct directory.")
        # Surface both underlying errors. Otherwise an ImportError raised
        # *inside* a project module (e.g. a missing dependency) cannot be told
        # apart from the package simply not being on sys.path.
        print(f"  modules.* import failed: {primary_error}", file=sys.stderr)
        print(f"  gateway.modules.* import failed: {fallback_error}", file=sys.stderr)
        sys.exit(1)
|
|
|
|
# Mandate and user IDs handed to getLucydomInterface / getWorkflowManager.
MANDATE_ID = 1
USER_ID = 1

# Prompt sent together with the uploaded files. Earlier prompts, kept for
# quick switching:
#   "Please analyze these sales figures and the chart to identify key trends and opportunities."
#   "Please make me a svg file with forecast for Apr-Jun."
USER_PROMPT = "Please make me a jpg file with forecast for Apr-Jun."
|
|
|
|
# Sample files to upload.
#
# SAMPLE_SVG: Q1 monthly sales bar chart. Bars use 1px per $1K and share the
# baseline y=230 (inside the translated group), so bar heights are
# proportional to the labelled values, which match the monthly revenue in
# SAMPLE_DATA. text-anchor="middle" on the group centres every label on its
# x coordinate (bar centres and axis midpoints).
SAMPLE_SVG = """
<svg width="400" height="300" xmlns="http://www.w3.org/2000/svg">
  <title>Sales Q1 Bar Chart</title>
  <rect width="100%" height="100%" fill="#f9f9f9"/>
  <g transform="translate(50, 20)" text-anchor="middle">
    <!-- Axes -->
    <line x1="0" y1="230" x2="320" y2="230" stroke="black" />
    <line x1="0" y1="0" x2="0" y2="230" stroke="black" />

    <!-- Y-axis title -->
    <text x="-30" y="120" transform="rotate(-90, -30, 120)">Sales ($)</text>

    <!-- X-axis title -->
    <text x="160" y="270">Month</text>

    <!-- January -->
    <rect x="40" y="80" width="60" height="150" fill="#4285F4" />
    <text x="70" y="250">Jan</text>
    <text x="70" y="70">$150K</text>

    <!-- February -->
    <rect x="130" y="65" width="60" height="165" fill="#EA4335" />
    <text x="160" y="250">Feb</text>
    <text x="160" y="55">$165K</text>

    <!-- March -->
    <rect x="220" y="50" width="60" height="180" fill="#FBBC05" />
    <text x="250" y="250">Mar</text>
    <text x="250" y="40">$180K</text>
  </g>
</svg>
"""
|
|
|
|
SAMPLE_DATA = """
|
|
# Sales Data - Q1 2023
|
|
|
|
Month,Revenue,Growth,Units Sold
|
|
January,150000,5.2%,1250
|
|
February,165000,10.0%,1380
|
|
March,180000,9.1%,1490
|
|
|
|
## Regional Breakdown
|
|
- North: 35% of total sales
|
|
- South: 25% of total sales
|
|
- East: 20% of total sales
|
|
- West: 20% of total sales
|
|
|
|
## Top Products
|
|
1. Product A: 40% of revenue
|
|
2. Product B: 30% of revenue
|
|
3. Product C: 20% of revenue
|
|
4. Others: 10% of revenue
|
|
"""
|
|
|
|
async def create_test_files(mydom):
    """Upload the sample chart and data files and return their IDs.

    Each file is passed to ``mydom.saveUploadedFile(content_bytes, filename)``
    as UTF-8 bytes; the ID is read from the ``'id'`` key of the returned
    metadata.

    Returns:
        Tuple ``(chart_id, data_id)`` in upload order.
    """
    print("\n--- Uploading Test Files (State 0: File Upload) ---")

    uploads = (
        ("SVG chart", SAMPLE_SVG, "q1_sales_chart.svg"),
        ("markdown data", SAMPLE_DATA, "q1_sales_data.md"),
    )
    file_ids = []
    for label, text, filename in uploads:
        print(f"Uploading {label} file...")
        meta = mydom.saveUploadedFile(text.encode('utf-8'), filename)
        file_ids.append(meta['id'])
        print(f"Created {label} file with ID: {file_ids[-1]}")

    chart_id, data_id = file_ids
    return chart_id, data_id
|
|
|
|
# Terminal workflow statuses mapped to the message printed when one is seen.
_TERMINAL_STATUS_MESSAGES = {
    "completed": "\nWorkflow completed successfully!",
    "failed": "\nWorkflow failed!",
    "stopped": "\nWorkflow was stopped!",
}


def _print_workflow_progress(workflow):
    """Print the workflow status with progress/message from its newest log entry.

    Nothing is printed while the workflow has no log entries.
    """
    logs = workflow.get("logs", [])
    if not logs:
        return
    latest_log = logs[-1]
    status = workflow.get("status", "unknown")
    progress = latest_log.get("progress", 0)
    message = latest_log.get("message", "No message")
    print(f"Status: {status} | Progress: {progress}% | {message}")


async def monitor_workflow(mydom, workflow_id, timeout=300, interval=2):
    """Poll a workflow until it reaches a terminal status or *timeout* expires.

    Args:
        mydom: Interface whose ``loadWorkflowState(workflow_id)`` returns the
            workflow as a dict, or a falsy value if it cannot be found.
        workflow_id: ID of the workflow to watch.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to wait between polls.

    Returns:
        The last workflow state read: terminal if one was reached, otherwise
        the state observed at the deadline. Returns ``None`` if the workflow
        could not be found.
    """
    print("\n--- Monitoring Workflow ---")
    loop = asyncio.get_running_loop()
    # The event loop clock is monotonic, so wall-clock adjustments cannot
    # shorten or extend the timeout.
    deadline = loop.time() + timeout

    while True:
        # NOTE: loadWorkflowState is called without await, so each poll runs
        # synchronously on the event loop.
        workflow = mydom.loadWorkflowState(workflow_id)
        if not workflow:
            print("Error: Workflow not found")
            return None

        _print_workflow_progress(workflow)

        status = workflow.get("status", "unknown")
        if status in _TERMINAL_STATUS_MESSAGES:
            print(_TERMINAL_STATUS_MESSAGES[status])
            return workflow

        # The deadline is checked only after inspecting a fresh state. A
        # workflow that finishes during the final sleep is then still
        # reported as finished instead of "timed out".
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        await asyncio.sleep(min(interval, remaining))

    print(f"Monitoring timed out after {timeout} seconds")
    return workflow
|
|
|
|
def _content_preview(content, limit=100):
    """Return *content* as text, cut to *limit* characters plus "..." if longer.

    ``None`` is treated as an empty message; any other value is passed through
    ``str()``.
    """
    text = "" if content is None else str(content)
    return text[:limit] + "..." if len(text) > limit else text


def _print_message(msg):
    """Print one workflow message: role/agent header, content preview, documents."""
    role = msg.get("role", "unknown")
    agent = msg.get("agentName", "")
    content_preview = _content_preview(msg.get("content"))

    if role == "assistant" and agent:
        print(f"\n[{role} - {agent}]: {content_preview}")
    else:
        print(f"\n[{role}]: {content_preview}")

    # ``or []`` also covers a key that is present but set to None.
    docs = msg.get("documents") or []
    if docs:
        print(f" Documents ({len(docs)}):")
        for doc in docs:
            name = doc.get("name", "unnamed")
            ext = doc.get("ext", "")
            file_id = doc.get("fileId", "unknown")
            # Only add the dot when there is an extension, so extension-less
            # documents are not shown as "name.".
            filename = f"{name}.{ext}" if ext else name
            print(f" - {filename} (ID: {file_id})")


def _print_workflow_results(final_workflow):
    """Print status, round, messages and the last log entry of a workflow state."""
    print("\n--- Final Workflow Results ---")
    if not final_workflow:
        print("Error: Could not retrieve final workflow state")
        return

    print(f"Workflow Status: {final_workflow.get('status', 'unknown')}")
    print(f"Current Round: {final_workflow.get('currentRound', 0)}")

    print("\n=== Messages ===")
    for msg in final_workflow.get("messages") or []:
        _print_message(msg)

    logs = final_workflow.get("logs") or []
    if logs:
        print(f"\nFinal Log: {logs[-1].get('message', 'No message')}")


async def run_test():
    """Run the end-to-end workflow test directly against the interfaces.

    Uploads the two sample files, starts a workflow with USER_PROMPT and both
    file IDs, polls it with monitor_workflow (120 s timeout), and then prints
    the stored final state.

    Returns:
        The ID of the workflow that was started.
    """
    print("\n=== Direct Interface Workflow Test ===\n")

    print("Initializing system...")
    mydom = getLucydomInterface(MANDATE_ID, USER_ID)
    manager = getWorkflowManager(MANDATE_ID, USER_ID)

    # State 0: file upload.
    chart_id, data_id = await create_test_files(mydom)

    user_input = {
        "prompt": USER_PROMPT,
        "listFileId": [chart_id, data_id],
    }

    # State 1: workflow initialization.
    print("\n--- Starting Workflow (State 1: Workflow Initialization) ---")
    print(f"Sending user prompt: '{USER_PROMPT}'")
    print(f"With files: SVG chart (ID: {chart_id}) and sales data (ID: {data_id})")

    workflow = await manager.workflowStart(user_input)
    workflow_id = workflow["id"]
    print(f"Workflow initiated with ID: {workflow_id}")
    print(f"Initial status: {workflow['status']}")

    # States 2-7 of the state machine are observed by polling.
    await monitor_workflow(mydom, workflow_id, timeout=120)

    # Re-read the stored state so the report reflects the latest workflow.
    _print_workflow_results(mydom.loadWorkflowState(workflow_id))

    print("\n=== Test Complete ===")
    return workflow_id
|
|
|
|
if __name__ == "__main__":
    # Drive the async test to completion and echo the workflow it created.
    print(f"Completed workflow ID: {asyncio.run(run_test())}")