"""
|
|
Datenanalyst-Agent für die Analyse und Interpretation von Daten.
|
|
Angepasst für das refaktorisierte Core-Modul mit AgentCommunicationProtocol.
|
|
"""
|
|
|
|
import logging
import traceback
import json
import re
import uuid
import io
import base64
from typing import List, Dict, Any, Optional, Union, Tuple
from datetime import datetime

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go

from modules.agentservice_base import BaseAgent
from connectors.connector_aichat_openai import ChatService
from modules.agentservice_utils import WorkflowUtils, MessageUtils, LoggingUtils, FileUtils
from modules.agentservice_protocol import AgentMessage, AgentCommunicationProtocol

logger = logging.getLogger(__name__)


class AnalystAgent(BaseAgent):
|
|
"""Agent for data analysis and interpretation"""
|
|
|
|
def __init__(self):
|
|
"""Initialize the data analyst agent"""
|
|
super().__init__()
|
|
self.id = "analyst_agent"
|
|
self.name = "Data Analyst"
|
|
self.type = "analyst"
|
|
self.description = "Analyzes and interprets data"
|
|
self.capabilities = "data_analysis,pattern_recognition,statistics,visualization,data_interpretation"
|
|
self.result_format = "AnalysisReport"
|
|
|
|
# Initialize AI service
|
|
self.ai_service = None
|
|
|
|
# Document capabilities
|
|
self.supports_documents = True
|
|
self.document_capabilities = ["read", "analyze", "extract"]
|
|
self.required_context = ["data_source", "analysis_objectives"]
|
|
self.document_handler = None
|
|
|
|
# Initialize protocol
|
|
self.protocol = AgentCommunicationProtocol()
|
|
|
|
# Initialize utilities
|
|
self.message_utils = MessageUtils()
|
|
self.file_utils = FileUtils()
|
|
|
|
# Setup visualization defaults
|
|
self.plt_style = 'seaborn-v0_8-whitegrid'
|
|
self.default_figsize = (10, 6)
|
|
self.chart_dpi = 100
|
|
plt.style.use(self.plt_style)
|
|
|
|
def get_agent_info(self) -> Dict[str, Any]:
|
|
"""Get agent information for agent registry"""
|
|
info = super().get_agent_info()
|
|
info.update({
|
|
"metadata": {
|
|
"supported_formats": ["csv", "xlsx", "json", "text"],
|
|
"analysis_types": ["statistical", "trend", "comparative", "predictive", "clustering", "general"],
|
|
"visualization_types": ["bar", "line", "scatter", "histogram", "box", "heatmap", "pie"]
|
|
}
|
|
})
|
|
return info
|
|
|
|
def set_document_handler(self, document_handler):
|
|
"""Set the document handler for file operations"""
|
|
self.document_handler = document_handler
|
|
|
|
async def process_message(self, message: Dict[str, Any], context: Dict[str, Any] = None) -> Dict[str, Any]:
|
|
"""
|
|
Process a message and perform data analysis.
|
|
|
|
Args:
|
|
message: Input message
|
|
context: Optional context
|
|
|
|
Returns:
|
|
Analysis response
|
|
"""
|
|
# Extract workflow_id from context or message
|
|
workflow_id = context.get("workflow_id") if context else message.get("workflow_id", "unknown")
|
|
|
|
# Get or create logging_utils
|
|
log_func = context.get("log_func") if context else None
|
|
logging_utils = LoggingUtils(workflow_id, log_func)
|
|
|
|
# Create status update using protocol
|
|
if log_func:
|
|
status_message = self.protocol.create_status_update_message(
|
|
status_description="Starting data analysis",
|
|
sender_id=self.id,
|
|
status="in_progress",
|
|
progress=0.0,
|
|
context_id=workflow_id
|
|
)
|
|
log_func(workflow_id, status_message.content, "info", self.id, self.name)
|
|
|
|
# Create response structure
|
|
response = {
|
|
"role": "assistant",
|
|
"content": "",
|
|
"agent_id": self.id,
|
|
"agent_type": self.type,
|
|
"agent_name": self.name,
|
|
"result_format": self.result_format,
|
|
"workflow_id": workflow_id,
|
|
"documents": []
|
|
}
|
|
|
|
try:
|
|
# Extract task from message
|
|
task = message.get("content", "")
|
|
|
|
# Process any attached documents and extract data
|
|
document_context = ""
|
|
data_frames = {}
|
|
|
|
if message.get("documents"):
|
|
logging_utils.info("Processing documents for analysis", "execution")
|
|
document_context, data_frames = await self._process_and_extract_data(message)
|
|
|
|
# Update progress
|
|
if log_func:
|
|
status_message = self.protocol.create_status_update_message(
|
|
status_description="Documents processed, performing analysis",
|
|
sender_id=self.id,
|
|
status="in_progress",
|
|
progress=0.4,
|
|
context_id=workflow_id
|
|
)
|
|
log_func(workflow_id, status_message.content, "info", self.id, self.name)
|
|
|
|
# If we don't have any data frames but expected to analyze data, report this issue
|
|
if not data_frames and any(term in task.lower() for term in ["analyze", "data", "csv", "excel", "file"]):
|
|
if message.get("documents"):
|
|
logging_utils.warning("No processable data found in the provided documents", "execution")
|
|
analysis_content = "## Data Analysis Report\n\nI couldn't find any processable data in the provided documents. Please ensure you've attached CSV, Excel, or other data files in a format I can analyze."
|
|
else:
|
|
logging_utils.warning("No documents provided for data analysis", "execution")
|
|
analysis_content = "## Data Analysis Report\n\nNo data documents were provided for analysis. Please attach CSV, Excel, or other data files for me to analyze."
|
|
|
|
response["content"] = analysis_content
|
|
return response
|
|
|
|
# Determine analysis type and perform analysis
|
|
analysis_type = self._determine_analysis_type(task)
|
|
logging_utils.info(f"Performing {analysis_type} analysis", "execution")
|
|
|
|
# Create enhanced prompt with document context
|
|
enhanced_prompt = self._create_enhanced_prompt(message, document_context, context)
|
|
|
|
# Generate visualization documents if data is available
|
|
visualization_documents = []
|
|
if data_frames:
|
|
logging_utils.info(f"Generating visualizations for {len(data_frames)} data sets", "execution")
|
|
visualization_documents = self._generate_visualizations(data_frames, analysis_type, workflow_id, task)
|
|
|
|
# Add visualizations to response documents
|
|
response["documents"].extend(visualization_documents)
|
|
|
|
# Update progress
|
|
if log_func:
|
|
status_message = self.protocol.create_status_update_message(
|
|
status_description="Visualizations created, finalizing analysis",
|
|
sender_id=self.id,
|
|
status="in_progress",
|
|
progress=0.7,
|
|
context_id=workflow_id
|
|
)
|
|
log_func(workflow_id, status_message.content, "info", self.id, self.name)
|
|
|
|
# Generate analysis with included data insights if we have data frames
|
|
analysis_content = ""
|
|
if data_frames:
|
|
# Extract data insights to include in the analysis
|
|
data_insights = self._extract_data_insights(data_frames)
|
|
|
|
# Add insights to the prompt
|
|
enhanced_prompt += f"\n\n=== DATA INSIGHTS ===\n{data_insights}"
|
|
|
|
# Generate analysis with data insights
|
|
analysis_content = await self._generate_analysis(enhanced_prompt, analysis_type)
|
|
|
|
# Include references to the visualization documents
|
|
if visualization_documents:
|
|
viz_references = "\n\n## Visualizations\n\n"
|
|
viz_references += "The following visualizations have been created to help understand the data:\n\n"
|
|
|
|
for i, doc in enumerate(visualization_documents, 1):
|
|
doc_source = doc.get("source", {})
|
|
doc_name = doc_source.get("name", f"Visualization {i}")
|
|
viz_references += f"{i}. **{doc_name}** - Available as an attached document\n"
|
|
|
|
analysis_content += viz_references
|
|
else:
|
|
# Generate analysis based just on text if no data frames
|
|
analysis_content = await self._generate_analysis(enhanced_prompt, analysis_type)
|
|
|
|
# Final progress update
|
|
if log_func:
|
|
status_message = self.protocol.create_status_update_message(
|
|
status_description="Analysis completed",
|
|
sender_id=self.id,
|
|
status="completed",
|
|
progress=1.0,
|
|
context_id=workflow_id
|
|
)
|
|
log_func(workflow_id, status_message.content, "info", self.id, self.name)
|
|
|
|
# Set the content in the response
|
|
response["content"] = analysis_content
|
|
|
|
# Finish by sending result message to protocol if needed
|
|
if context and context.get("require_protocol_message"):
|
|
result_message = self.send_analysis_result(
|
|
analysis_content=analysis_content,
|
|
sender_id=self.id,
|
|
receiver_id=context.get("receiver_id", "workflow"),
|
|
task_id=context.get("task_id", f"analysis_{uuid.uuid4()}"),
|
|
analysis_data={
|
|
"analysis_type": analysis_type,
|
|
"visualization_count": len(visualization_documents),
|
|
"data_frame_count": len(data_frames)
|
|
},
|
|
context_id=workflow_id
|
|
)
|
|
# Just log the message creation, don't need to return it
|
|
logging_utils.info(f"Created protocol result message: {result_message.id}", "execution")
|
|
|
|
return response
|
|
|
|
except Exception as e:
|
|
error_msg = f"Error during data analysis: {str(e)}"
|
|
logging_utils.error(error_msg, "error")
|
|
|
|
# Create error response using protocol
|
|
error_message = self.protocol.create_error_message(
|
|
error_description=error_msg,
|
|
sender_id=self.id,
|
|
error_type="analysis",
|
|
error_details={"traceback": traceback.format_exc()},
|
|
context_id=workflow_id
|
|
)
|
|
|
|
# Set error content in the response
|
|
response["content"] = f"## Error during data analysis\n\n{error_msg}\n\n```\n{traceback.format_exc()}\n```"
|
|
response["status"] = "error"
|
|
|
|
return response
|
|
|
|
async def _process_and_extract_data(self, message: Dict[str, Any]) -> Tuple[str, Dict[str, pd.DataFrame]]:
|
|
"""
|
|
Process documents and extract structured data.
|
|
|
|
Args:
|
|
message: Input message with documents
|
|
|
|
Returns:
|
|
Tuple of (document_context, data_frames_dict)
|
|
"""
|
|
document_context = ""
|
|
data_frames = {}
|
|
|
|
if not message.get("documents"):
|
|
return document_context, data_frames
|
|
|
|
# Extract document text (this will be our context)
|
|
if self.document_handler:
|
|
document_context = self.document_handler.merge_document_contents(message)
|
|
else:
|
|
document_context = self._extract_document_text(message)
|
|
|
|
# Identify and process data files (CSV, Excel, etc.)
|
|
for document in message.get("documents", []):
|
|
source = document.get("source", {})
|
|
filename = source.get("name", "")
|
|
file_id = source.get("id", 0)
|
|
content_type = source.get("content_type", "")
|
|
|
|
# Skip if not a recognizable data file
|
|
if not self._is_data_file(filename, content_type):
|
|
continue
|
|
|
|
try:
|
|
# Try to get file content through document handler first
|
|
file_content = None
|
|
if self.document_handler:
|
|
file_content = self.document_handler.get_file_content_from_message(message, file_id=file_id)
|
|
|
|
# Process based on file type
|
|
if filename.lower().endswith('.csv'):
|
|
df = self._process_csv(file_content, filename)
|
|
if df is not None:
|
|
data_frames[filename] = df
|
|
|
|
elif filename.lower().endswith(('.xlsx', '.xls')):
|
|
dfs = self._process_excel(file_content, filename)
|
|
for sheet_name, df in dfs.items():
|
|
data_frames[f"{filename}::{sheet_name}"] = df
|
|
|
|
elif filename.lower().endswith('.json'):
|
|
df = self._process_json(file_content, filename)
|
|
if df is not None:
|
|
data_frames[filename] = df
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing file {filename}: {str(e)}")
|
|
|
|
return document_context, data_frames
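    # For reference (illustrative, derived from how this module reads and creates
    # document objects): the documents consumed above are dicts shaped roughly like
    #
    #   {
    #       "source": {"id": 1, "name": "sales.csv", "content_type": "text/csv", "size": 1234},
    #       "contents": [{"type": "text", "text": "..."}],
    #   }
    #
    # Only the fields actually accessed in this file (source.name, source.id,
    # source.content_type, contents[].type/text) are assumed here.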
|
|
|
|
def _is_data_file(self, filename: str, content_type: str) -> bool:
|
|
"""Check if a file is a processable data file"""
|
|
if filename.lower().endswith(('.csv', '.xlsx', '.xls', '.json')):
|
|
return True
|
|
|
|
if content_type in ['text/csv', 'application/vnd.ms-excel',
|
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
|
'application/json']:
|
|
return True
|
|
|
|
return False
|
|
|
|
def _process_csv(self, file_content: Union[bytes, str], filename: str) -> Optional[pd.DataFrame]:
|
|
"""Process CSV file content into a pandas DataFrame"""
|
|
if file_content is None:
|
|
return None
|
|
|
|
try:
|
|
# Handle the case where file_content is already a string
|
|
if isinstance(file_content, str):
|
|
text_content = file_content
|
|
df = pd.read_csv(io.StringIO(text_content))
|
|
df = self._preprocess_dataframe(df)
|
|
return df
|
|
|
|
# Handle the case where file_content is bytes
|
|
else:
|
|
# Try various encodings
|
|
for encoding in ['utf-8', 'latin1', 'cp1252']:
|
|
try:
|
|
# Use StringIO to create a file-like object
|
|
text_content = file_content.decode(encoding)
|
|
df = pd.read_csv(io.StringIO(text_content))
|
|
|
|
# Basic preprocessing
|
|
df = self._preprocess_dataframe(df)
|
|
return df
|
|
except UnicodeDecodeError:
|
|
continue
|
|
except Exception as e:
|
|
logger.error(f"Error processing CSV with {encoding} encoding: {str(e)}")
|
|
|
|
# If all encodings fail, try one more time with errors='replace'
|
|
text_content = file_content.decode('utf-8', errors='replace')
|
|
df = pd.read_csv(io.StringIO(text_content))
|
|
df = self._preprocess_dataframe(df)
|
|
return df
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to process CSV file {filename}: {str(e)}")
|
|
return None
|
|
|
|
def _process_excel(self, file_content: bytes, filename: str) -> Dict[str, pd.DataFrame]:
|
|
"""Process Excel file content into pandas DataFrames"""
|
|
result = {}
|
|
|
|
if file_content is None:
|
|
return result
|
|
|
|
try:
|
|
# Use BytesIO to create a file-like object
|
|
excel_file = io.BytesIO(file_content)
|
|
|
|
# Try to read with pandas
|
|
excel_data = pd.ExcelFile(excel_file)
|
|
|
|
# Process each sheet
|
|
for sheet_name in excel_data.sheet_names:
|
|
                # Parse from the already-open ExcelFile so the underlying buffer is not re-read
                df = pd.read_excel(excel_data, sheet_name=sheet_name)
|
|
|
|
# Basic preprocessing
|
|
df = self._preprocess_dataframe(df)
|
|
|
|
# Only include if there's actual data
|
|
if not df.empty:
|
|
result[sheet_name] = df
|
|
|
|
return result
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to process Excel file {filename}: {str(e)}")
|
|
return result
|
|
|
|
def _process_json(self, file_content: bytes, filename: str) -> Optional[pd.DataFrame]:
|
|
"""Process JSON file content into a pandas DataFrame"""
|
|
if file_content is None:
|
|
return None
|
|
|
|
try:
|
|
# Decode and parse JSON
|
|
json_content = file_content.decode('utf-8')
|
|
data = json.loads(json_content)
|
|
|
|
# Handle different JSON structures
|
|
if isinstance(data, list):
|
|
# List of records
|
|
df = pd.DataFrame(data)
|
|
elif isinstance(data, dict):
|
|
# Try to find a suitable data structure in the dict
|
|
if any(isinstance(v, list) for v in data.values()):
|
|
# Find the first list value to use as data
|
|
for key, value in data.items():
|
|
if isinstance(value, list) and len(value) > 0:
|
|
df = pd.DataFrame(value)
|
|
break
|
|
else:
|
|
# No suitable list found
|
|
return None
|
|
else:
|
|
# Convert flat dict to a single-row DataFrame
|
|
df = pd.DataFrame([data])
|
|
else:
|
|
# Unsupported structure
|
|
return None
|
|
|
|
# Basic preprocessing
|
|
df = self._preprocess_dataframe(df)
|
|
return df
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to process JSON file {filename}: {str(e)}")
|
|
return None
|
|
|
|
def _preprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
"""Perform basic preprocessing on a DataFrame"""
|
|
if df.empty:
|
|
return df
|
|
|
|
# Remove completely empty rows and columns
|
|
df = df.dropna(how='all')
|
|
df = df.dropna(axis=1, how='all')
|
|
|
|
# Try to convert string columns to numeric where appropriate
|
|
for col in df.columns:
|
|
# Skip if already numeric
|
|
if pd.api.types.is_numeric_dtype(df[col]):
|
|
continue
|
|
|
|
# Skip if mostly non-numeric strings
|
|
if df[col].dtype == 'object':
|
|
# Check if more than 80% of non-NA values could be numeric
|
|
non_na_values = df[col].dropna()
|
|
if len(non_na_values) == 0:
|
|
continue
|
|
|
|
# Try to convert to numeric and count successes
|
|
numeric_count = pd.to_numeric(non_na_values, errors='coerce').notna().sum()
|
|
if numeric_count / len(non_na_values) > 0.8:
|
|
# More than 80% can be converted to numeric, so convert the column
|
|
df[col] = pd.to_numeric(df[col], errors='coerce')
|
|
|
|
# Try to parse date columns
|
|
for col in df.columns:
|
|
# Skip if not object dtype
|
|
if df[col].dtype != 'object':
|
|
continue
|
|
|
|
# Check if column name suggests a date
|
|
            if any(date_term in col.lower() for date_term in ['date', 'time', 'day', 'month', 'year']):
                try:
                    # Parse into a separate series so the original values are preserved
                    # when parsing quality is poor
                    converted = pd.to_datetime(df[col], errors='coerce')
                    # Only keep the conversion if at least 80% of values parsed successfully
                    if converted.notna().mean() >= 0.8:
                        df[col] = converted
                except (ValueError, TypeError):
                    pass
|
|
|
|
return df
|
|
|
|
def _extract_document_text(self, message: Dict[str, Any]) -> str:
|
|
"""
|
|
Extract text from documents (fallback method).
|
|
|
|
Args:
|
|
message: Input message with documents
|
|
|
|
Returns:
|
|
Extracted text
|
|
"""
|
|
text_content = ""
|
|
for document in message.get("documents", []):
|
|
source = document.get("source", {})
|
|
name = source.get("name", "unnamed")
|
|
|
|
text_content += f"\n\n--- {name} ---\n"
|
|
|
|
for content in document.get("contents", []):
|
|
if content.get("type") == "text":
|
|
text_content += content.get("text", "")
|
|
|
|
return text_content
|
|
|
|
def _determine_analysis_type(self, task: str) -> str:
|
|
"""
|
|
Determine the type of analysis based on the task.
|
|
|
|
Args:
|
|
task: The analysis task
|
|
|
|
Returns:
|
|
Analysis type
|
|
"""
|
|
task_lower = task.lower()
|
|
|
|
# Check for statistical analysis
|
|
if any(term in task_lower for term in ["statistics", "statistical", "mean", "median", "variance"]):
|
|
return "statistical"
|
|
|
|
# Check for trend analysis
|
|
elif any(term in task_lower for term in ["trend", "pattern", "time series", "historical"]):
|
|
return "trend"
|
|
|
|
# Check for comparative analysis
|
|
elif any(term in task_lower for term in ["compare", "comparison", "versus", "vs", "difference"]):
|
|
return "comparative"
|
|
|
|
# Check for predictive analysis
|
|
elif any(term in task_lower for term in ["predict", "forecast", "future", "projection"]):
|
|
return "predictive"
|
|
|
|
# Check for clustering or categorization
|
|
elif any(term in task_lower for term in ["cluster", "segment", "categorize", "classify"]):
|
|
return "clustering"
|
|
|
|
# Default to general analysis
|
|
else:
|
|
return "general"
|
|
|
|
def _extract_data_insights(self, data_frames: Dict[str, pd.DataFrame]) -> str:
|
|
"""
|
|
Extract basic insights from data frames.
|
|
|
|
Args:
|
|
data_frames: Dictionary of data frames
|
|
|
|
Returns:
|
|
Extracted insights as text
|
|
"""
|
|
insights = []
|
|
|
|
for name, df in data_frames.items():
|
|
if df.empty:
|
|
continue
|
|
|
|
insight = f"Dataset: {name}\n"
|
|
insight += f"Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
|
insight += f"Columns: {', '.join(df.columns.tolist())}\n"
|
|
|
|
# Basic statistics for numeric columns
|
|
numeric_cols = df.select_dtypes(include=['number']).columns
|
|
if len(numeric_cols) > 0:
|
|
insight += "Numeric column statistics:\n"
|
|
for col in numeric_cols[:5]: # Limit to first 5 columns
|
|
stats = df[col].describe()
|
|
insight += f" {col}: min={stats['min']:.2f}, max={stats['max']:.2f}, mean={stats['mean']:.2f}, median={df[col].median():.2f}\n"
|
|
|
|
if len(numeric_cols) > 5:
|
|
insight += f" ... and {len(numeric_cols) - 5} more numeric columns\n"
|
|
|
|
# Date range for datetime columns
|
|
date_cols = df.select_dtypes(include=['datetime']).columns
|
|
if len(date_cols) > 0:
|
|
insight += "Date range:\n"
|
|
for col in date_cols:
|
|
if df[col].notna().any():
|
|
min_date = df[col].min()
|
|
max_date = df[col].max()
|
|
insight += f" {col}: {min_date} to {max_date}\n"
|
|
|
|
# Categorical column value counts
|
|
cat_cols = df.select_dtypes(include=['object', 'category']).columns
|
|
if len(cat_cols) > 0:
|
|
insight += "Categorical columns:\n"
|
|
for col in cat_cols[:3]: # Limit to first 3 columns
|
|
# Get top 3 values
|
|
top_values = df[col].value_counts().head(3)
|
|
vals_str = ", ".join([f"{val} ({count})" for val, count in top_values.items()])
|
|
insight += f" {col}: {df[col].nunique()} unique values. Top values: {vals_str}\n"
|
|
|
|
if len(cat_cols) > 3:
|
|
insight += f" ... and {len(cat_cols) - 3} more categorical columns\n"
|
|
|
|
# Missing values
|
|
missing = df.isna().sum()
|
|
if missing.sum() > 0:
|
|
cols_with_missing = missing[missing > 0]
|
|
insight += "Missing values:\n"
|
|
for col, count in cols_with_missing.items():
|
|
pct = 100 * count / len(df)
|
|
insight += f" {col}: {count} missing values ({pct:.1f}%)\n"
|
|
|
|
insights.append(insight)
|
|
|
|
return "\n\n".join(insights)
|
|
|
|
def _generate_visualizations(self, data_frames: Dict[str, pd.DataFrame], analysis_type: str,
|
|
workflow_id: str, task: str) -> List[Dict[str, Any]]:
|
|
"""
|
|
Generate appropriate visualizations based on data and analysis type.
|
|
|
|
Args:
|
|
data_frames: Dictionary of DataFrames to visualize
|
|
analysis_type: Type of analysis being performed
|
|
workflow_id: Workflow ID
|
|
task: Original task description
|
|
|
|
Returns:
|
|
List of visualization document objects
|
|
"""
|
|
documents = []
|
|
|
|
for name, df in data_frames.items():
|
|
if df.empty or df.shape[0] < 2:
|
|
continue # Skip empty or single-row DataFrames
|
|
|
|
# Generate different visualizations based on the analysis type
|
|
if analysis_type == "statistical":
|
|
viz_docs = self._create_statistical_visualizations(df, name, workflow_id)
|
|
documents.extend(viz_docs)
|
|
|
|
elif analysis_type == "trend":
|
|
viz_docs = self._create_trend_visualizations(df, name, workflow_id)
|
|
documents.extend(viz_docs)
|
|
|
|
elif analysis_type == "comparative":
|
|
viz_docs = self._create_comparative_visualizations(df, name, workflow_id)
|
|
documents.extend(viz_docs)
|
|
|
|
elif analysis_type == "predictive":
|
|
viz_docs = self._create_predictive_visualizations(df, name, workflow_id)
|
|
documents.extend(viz_docs)
|
|
|
|
elif analysis_type == "clustering":
|
|
viz_docs = self._create_clustering_visualizations(df, name, workflow_id)
|
|
documents.extend(viz_docs)
|
|
|
|
else: # general analysis
|
|
viz_docs = self._create_general_visualizations(df, name, workflow_id)
|
|
documents.extend(viz_docs)
|
|
|
|
return documents
|
|
|
|
def _create_statistical_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
|
|
"""Create statistical visualizations for a DataFrame"""
|
|
documents = []
|
|
|
|
# 1. Distribution/Histogram plots for numeric columns
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:5] # Limit to first 5
|
|
if len(numeric_cols) > 0:
|
|
plt.figure(figsize=(12, 8))
|
|
|
|
for i, col in enumerate(numeric_cols, 1):
|
|
plt.subplot(len(numeric_cols), 1, i)
|
|
sns.histplot(df[col].dropna(), kde=True)
|
|
plt.title(f'Distribution of {col}')
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_stat_dist_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Statistical Distributions - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# 2. Box plots for numeric columns
|
|
if len(numeric_cols) > 0:
|
|
plt.figure(figsize=(12, 8))
|
|
sns.boxplot(data=df[numeric_cols])
|
|
plt.title(f'Box Plots of Numeric Variables in {name}')
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_stat_box_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Box Plots - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# 3. Correlation heatmap for numeric columns
|
|
if len(numeric_cols) >= 2:
|
|
plt.figure(figsize=(10, 8))
|
|
corr = df[numeric_cols].corr()
|
|
sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
|
|
plt.title(f'Correlation Heatmap - {name}')
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_stat_corr_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Correlation Heatmap - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
return documents
|
|
|
|
def _create_trend_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
|
|
"""Create trend visualizations for a DataFrame"""
|
|
documents = []
|
|
|
|
# Check for date/time columns
|
|
date_cols = df.select_dtypes(include=['datetime']).columns
|
|
|
|
# If we have date columns, create time series plots
|
|
if len(date_cols) > 0:
|
|
date_col = date_cols[0] # Use the first date column
|
|
|
|
# Find numeric columns to plot against the date
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:3] # Limit to first 3
|
|
|
|
if len(numeric_cols) > 0:
|
|
plt.figure(figsize=(12, 8))
|
|
|
|
for i, col in enumerate(numeric_cols, 1):
|
|
plt.subplot(len(numeric_cols), 1, i)
|
|
plt.plot(df[date_col], df[col])
|
|
plt.title(f'Trend of {col} over time')
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_trend_time_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Time Series Trends - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# If no date columns found, find another column that might represent sequence/order
|
|
else:
|
|
# Look for columns with sequential numbers
|
|
potential_sequence_cols = []
|
|
for col in df.select_dtypes(include=['number']).columns:
|
|
values = df[col].dropna().values
|
|
if len(values) >= 5:
|
|
# Check if values are mostly sequential
|
|
diffs = np.diff(sorted(values))
|
|
if np.all(diffs > 0) and np.std(diffs) / np.mean(diffs) < 0.5:
|
|
potential_sequence_cols.append(col)
|
|
|
|
# Use first potential sequence column or first numeric column
|
|
numeric_cols = df.select_dtypes(include=['number']).columns
|
|
if len(potential_sequence_cols) > 0 and len(numeric_cols) > 1:
|
|
sequence_col = potential_sequence_cols[0]
|
|
# Find other numeric columns to plot against the sequence
|
|
plot_cols = [col for col in numeric_cols if col != sequence_col][:2]
|
|
|
|
plt.figure(figsize=(12, 6))
|
|
for col in plot_cols:
|
|
plt.plot(df[sequence_col], df[col], marker='o', label=col)
|
|
plt.title(f'Trend by {sequence_col} - {name}')
|
|
plt.legend()
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_trend_seq_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Sequential Trends - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# Moving average visualization if we have enough data points
|
|
if len(df) > 10:
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:2] # Limit to first 2
|
|
if len(numeric_cols) > 0:
|
|
plt.figure(figsize=(12, 6))
|
|
|
|
for col in numeric_cols:
|
|
# Sort data if we have a date column
|
|
if len(date_cols) > 0:
|
|
sorted_df = df.sort_values(by=date_cols[0])
|
|
else:
|
|
sorted_df = df
|
|
|
|
# Calculate moving average (window size 3)
|
|
values = sorted_df[col].values
|
|
window_size = min(3, len(values) - 1)
|
|
if window_size > 0:
|
|
moving_avg = np.convolve(values, np.ones(window_size)/window_size, mode='valid')
|
|
|
|
# Plot original and moving average
|
|
plt.plot(values, label=f'{col} (Original)')
|
|
plt.plot(np.arange(window_size-1, len(values)), moving_avg, label=f'{col} (Moving Avg)')
|
|
|
|
plt.title(f'Moving Average Trends - {name}')
|
|
plt.legend()
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_trend_mavg_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Moving Average Trends - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
return documents
|
|
|
|
def _create_comparative_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
|
|
"""Create comparative visualizations for a DataFrame"""
|
|
documents = []
|
|
|
|
# 1. Look for categorical columns to use for grouping
|
|
cat_cols = df.select_dtypes(include=['object', 'category']).columns
|
|
|
|
if len(cat_cols) > 0:
|
|
# Use the first categorical column with reasonable number of unique values
|
|
groupby_col = None
|
|
for col in cat_cols:
|
|
unique_count = df[col].nunique()
|
|
if 2 <= unique_count <= 10: # Reasonable number of categories
|
|
groupby_col = col
|
|
break
|
|
|
|
if groupby_col:
|
|
# Find numeric columns to compare across groups
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:3] # Limit to first 3
|
|
|
|
if len(numeric_cols) > 0:
|
|
# 1. Bar chart comparing means
|
|
plt.figure(figsize=(12, 6))
|
|
mean_by_group = df.groupby(groupby_col)[numeric_cols].mean()
|
|
mean_by_group.plot(kind='bar')
|
|
plt.title(f'Mean Comparison by {groupby_col} - {name}')
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_comp_bar_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Mean Comparison by {groupby_col} - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# 2. Box plots for comparing distributions
|
|
plt.figure(figsize=(12, 8))
|
|
for i, col in enumerate(numeric_cols, 1):
|
|
plt.subplot(len(numeric_cols), 1, i)
|
|
sns.boxplot(x=groupby_col, y=col, data=df)
|
|
plt.title(f'Distribution of {col} by {groupby_col}')
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_comp_box_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Distribution Comparison - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# 3. Scatter plot comparing two numeric variables
|
|
numeric_cols = df.select_dtypes(include=['number']).columns
|
|
if len(numeric_cols) >= 2:
|
|
plt.figure(figsize=(10, 8))
|
|
# Use first two numeric columns
|
|
x_col, y_col = numeric_cols[0], numeric_cols[1]
|
|
|
|
scatter = plt.scatter(df[x_col], df[y_col])
|
|
plt.title(f'Comparison of {x_col} vs {y_col} - {name}')
|
|
plt.xlabel(x_col)
|
|
plt.ylabel(y_col)
|
|
|
|
# Add color if we have a categorical column
|
|
if len(cat_cols) > 0:
|
|
groupby_col = cat_cols[0]
|
|
                if df[groupby_col].nunique() <= 10:  # Reasonable number of categories
                    # Redraw as a color-coded scatter; close the plain figure first so it is not leaked
                    plt.close()
                    plt.figure(figsize=(10, 8))
|
|
scatter = plt.scatter(df[x_col], df[y_col], c=pd.factorize(df[groupby_col])[0], cmap='viridis')
|
|
plt.title(f'Comparison of {x_col} vs {y_col} by {groupby_col} - {name}')
|
|
plt.xlabel(x_col)
|
|
plt.ylabel(y_col)
|
|
legend1 = plt.legend(scatter.legend_elements()[0], df[groupby_col].unique(), title=groupby_col)
|
|
plt.gca().add_artist(legend1)
|
|
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_comp_scatter_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Variable Comparison - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
return documents
|
|
|
|
def _create_predictive_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
|
|
"""Create predictive visualizations for a DataFrame"""
|
|
documents = []
|
|
|
|
# Check for date/time columns for time series prediction
|
|
date_cols = df.select_dtypes(include=['datetime']).columns
|
|
|
|
if len(date_cols) > 0:
|
|
date_col = date_cols[0] # Use the first date column
|
|
|
|
# Sort by date
|
|
df_sorted = df.sort_values(by=date_col)
|
|
|
|
# Find numeric columns to predict
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:2] # Limit to first 2
|
|
|
|
if len(numeric_cols) > 0:
|
|
plt.figure(figsize=(12, 8))
|
|
|
|
for i, col in enumerate(numeric_cols, 1):
|
|
plt.subplot(len(numeric_cols), 1, i)
|
|
|
|
# Get values and dates
|
|
                    values = df_sorted[col].to_numpy(dtype=float)
                    # Keep dates as a pandas datetime series so time deltas can be
                    # converted to seconds reliably
                    dates = pd.to_datetime(df_sorted[date_col])

                    # Need minimum number of points for meaningful prediction
                    if len(values) >= 5:
                        # Use basic linear regression for prediction.
                        # Convert dates to numeric values (seconds since the first date) for regression
                        date_nums = (dates - dates.iloc[0]).dt.total_seconds().to_numpy()
                        max_seconds = date_nums.max()
                        if max_seconds > 0:
                            date_nums = date_nums / max_seconds  # Normalize to [0, 1]

                        # Remove NaNs in both the values and the converted dates
                        mask = ~np.isnan(values) & ~np.isnan(date_nums)
                        if np.sum(mask) >= 3:  # Need at least 3 points
                            x = date_nums[mask].reshape(-1, 1)
                            y = values[mask]

                            # Fit linear regression
                            from sklearn.linear_model import LinearRegression
                            model = LinearRegression()
                            model.fit(x, y)

                            # Extend about 20% beyond the observed range for prediction
                            x_extended = np.linspace(0, 1.2, 100).reshape(-1, 1)
                            y_extended = model.predict(x_extended)

                            # Convert the extended range back to dates for plotting
                            future_seconds = x_extended.flatten() * max_seconds
                            future_dates = [dates.iloc[0] + pd.Timedelta(seconds=s) for s in future_seconds]

                            # Plot actual values, the extrapolated fit, and a marker at the last observation
                            plt.plot(dates, values, 'o-', label='Actual')
                            plt.plot(future_dates, y_extended, '--', label='Predicted')
                            plt.axvline(x=dates.iloc[-1], color='r', linestyle=':', label='Current')
|
|
|
|
plt.title(f'Prediction for {col}')
|
|
plt.xticks(rotation=45)
|
|
plt.legend()
|
|
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_pred_time_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Time Series Prediction - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# Regression prediction (feature vs target)
|
|
numeric_cols = df.select_dtypes(include=['number']).columns
|
|
if len(numeric_cols) >= 2:
|
|
plt.figure(figsize=(10, 8))
|
|
|
|
# Use first two numeric columns as feature and target
|
|
x_col, y_col = numeric_cols[0], numeric_cols[1]
|
|
|
|
# Remove NaNs
|
|
df_clean = df[[x_col, y_col]].dropna()
|
|
|
|
if len(df_clean) >= 5: # Need minimum points for regression
|
|
x = df_clean[x_col].values.reshape(-1, 1)
|
|
y = df_clean[y_col].values
|
|
|
|
# Fit linear regression
|
|
from sklearn.linear_model import LinearRegression
|
|
model = LinearRegression()
|
|
model.fit(x, y)
|
|
|
|
# Generate predictions
|
|
x_range = np.linspace(df_clean[x_col].min(), df_clean[x_col].max() * 1.1, 100).reshape(-1, 1)
|
|
y_pred = model.predict(x_range)
|
|
|
|
# Plot
|
|
plt.scatter(df_clean[x_col], df_clean[y_col], label='Data Points')
|
|
plt.plot(x_range, y_pred, 'r--', label=f'Predicted {y_col}')
|
|
plt.title(f'Regression Prediction: {y_col} based on {x_col} - {name}')
|
|
plt.xlabel(x_col)
|
|
plt.ylabel(y_col)
|
|
plt.legend()
|
|
|
|
# Add regression equation
|
|
slope = model.coef_[0]
|
|
intercept = model.intercept_
|
|
plt.text(0.05, 0.95, f'{y_col} = {slope:.2f} * {x_col} + {intercept:.2f}',
|
|
transform=plt.gca().transAxes, fontsize=10,
|
|
verticalalignment='top')
|
|
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_pred_reg_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Regression Prediction - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
return documents
|
|
|
|
def _create_clustering_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
|
|
"""Create clustering visualizations for a DataFrame"""
|
|
documents = []
|
|
|
|
# Need numeric columns for clustering
|
|
numeric_cols = df.select_dtypes(include=['number']).columns
|
|
if len(numeric_cols) >= 2:
|
|
# Select two numeric columns for 2D visualization
|
|
cols = numeric_cols[:2]
|
|
|
|
# Remove NaNs
|
|
df_clean = df[cols].dropna()
|
|
|
|
if len(df_clean) >= 5: # Need minimum points for clustering
|
|
# Normalize data
|
|
from sklearn.preprocessing import StandardScaler
|
|
scaler = StandardScaler()
|
|
data_scaled = scaler.fit_transform(df_clean)
|
|
|
|
# Apply K-means clustering
|
|
from sklearn.cluster import KMeans
|
|
# Determine number of clusters (2-5 based on data size)
|
|
n_clusters = min(max(2, len(df_clean) // 10), 5)
|
|
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
|
|
clusters = kmeans.fit_predict(data_scaled)
|
|
|
|
# Add cluster labels to DataFrame
|
|
df_clean['Cluster'] = clusters
|
|
|
|
# Create scatter plot with clusters
|
|
plt.figure(figsize=(10, 8))
|
|
|
|
# Plot clusters
|
|
scatter = plt.scatter(df_clean[cols[0]], df_clean[cols[1]], c=df_clean['Cluster'], cmap='viridis')
|
|
|
|
# Plot centroids
|
|
centroids = scaler.inverse_transform(kmeans.cluster_centers_)
|
|
plt.scatter(centroids[:, 0], centroids[:, 1], marker='X', s=200, c='red', label='Centroids')
|
|
|
|
plt.title(f'K-means Clustering ({n_clusters} clusters) - {name}')
|
|
plt.xlabel(cols[0])
|
|
plt.ylabel(cols[1])
|
|
                # Keep both the cluster legend and the centroid legend
                legend1 = plt.legend(*scatter.legend_elements(), title="Clusters")
                plt.gca().add_artist(legend1)
                plt.legend()
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_clust_kmeans_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"K-means Clustering - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# If we have more than 2 numeric columns, also create a PCA visualization
|
|
if len(numeric_cols) > 2:
|
|
from sklearn.decomposition import PCA
|
|
|
|
# Select more columns for PCA
|
|
pca_cols = numeric_cols[:min(len(numeric_cols), 5)]
|
|
|
|
# Remove NaNs
|
|
df_pca = df[pca_cols].dropna()
|
|
|
|
            if len(df_pca) >= 5:
                # Import locally so this branch does not depend on the K-means branch above having run
                from sklearn.preprocessing import StandardScaler
                from sklearn.cluster import KMeans

                # Normalize data
                pca_data = StandardScaler().fit_transform(df_pca)

                # Apply PCA to reduce to 2 dimensions
                pca = PCA(n_components=2)
                principal_components = pca.fit_transform(pca_data)

                # Create DataFrame with principal components
                pca_df = pd.DataFrame(data=principal_components, columns=['PC1', 'PC2'])

                # Apply clustering to PCA results (2-5 clusters depending on data size)
                n_clusters = min(max(2, len(df_pca) // 10), 5)
                clusters = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(pca_df)
                pca_df['Cluster'] = clusters
|
|
|
|
# Create scatter plot
|
|
plt.figure(figsize=(10, 8))
|
|
scatter = plt.scatter(pca_df['PC1'], pca_df['PC2'], c=pca_df['Cluster'], cmap='viridis')
|
|
plt.title(f'PCA Clustering ({n_clusters} clusters) - {name}')
|
|
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
|
|
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
|
|
plt.legend(*scatter.legend_elements(), title="Clusters")
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_clust_pca_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"PCA Clustering - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
return documents
|
|
|
|
def _create_general_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
|
|
"""Create general purpose visualizations for a DataFrame"""
|
|
documents = []
|
|
|
|
# 1. Data overview: numeric summary
|
|
numeric_cols = df.select_dtypes(include=['number']).columns
|
|
if len(numeric_cols) > 0:
|
|
# Create a bar chart of means for numeric columns
|
|
plt.figure(figsize=(12, 6))
|
|
means = df[numeric_cols].mean().sort_values()
|
|
means.plot(kind='bar')
|
|
plt.title(f'Mean Values of Numeric Variables - {name}')
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_gen_means_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Numeric Variables Summary - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# 2. Categorical data overview
|
|
cat_cols = df.select_dtypes(include=['object', 'category']).columns
|
|
if len(cat_cols) > 0:
|
|
# Select the first categorical column with reasonable cardinality
|
|
for col in cat_cols:
|
|
if df[col].nunique() <= 10: # Reasonable number of categories
|
|
plt.figure(figsize=(10, 6))
|
|
value_counts = df[col].value_counts().sort_values(ascending=False)
|
|
value_counts.plot(kind='bar')
|
|
plt.title(f'Distribution of {col} - {name}')
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_gen_cat_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Categorical Distribution - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
break # Only use the first suitable column
|
|
|
|
# 3. Correlation matrix if we have multiple numeric columns
|
|
if len(numeric_cols) >= 2:
|
|
plt.figure(figsize=(10, 8))
|
|
corr = df[numeric_cols].corr()
|
|
sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
|
|
plt.title(f'Correlation Matrix - {name}')
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_gen_corr_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Correlation Matrix - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
# 4. If we have date columns, show time-based visualization
|
|
date_cols = df.select_dtypes(include=['datetime']).columns
|
|
if len(date_cols) > 0 and len(numeric_cols) > 0:
|
|
date_col = date_cols[0] # Use the first date column
|
|
num_col = numeric_cols[0] # Use the first numeric column
|
|
|
|
plt.figure(figsize=(12, 6))
|
|
plt.plot(df[date_col], df[num_col], marker='o')
|
|
plt.title(f'{num_col} over Time - {name}')
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
|
|
# Save figure
|
|
img_data = self._get_figure_as_base64()
|
|
plt.close()
|
|
|
|
# Create document
|
|
doc_id = f"viz_gen_time_{uuid.uuid4()}"
|
|
doc = {
|
|
"id": doc_id,
|
|
"source": {
|
|
"type": "generated",
|
|
"id": doc_id,
|
|
"name": f"Time Series Overview - {name}",
|
|
"content_type": "image/png",
|
|
"size": len(img_data)
|
|
},
|
|
"contents": [{
|
|
"type": "image",
|
|
"data": img_data,
|
|
"format": "base64"
|
|
}]
|
|
}
|
|
documents.append(doc)
|
|
|
|
return documents
|
|
|
|
def _get_figure_as_base64(self) -> str:
|
|
"""Convert current matplotlib figure to base64 string"""
|
|
buffer = io.BytesIO()
|
|
plt.savefig(buffer, format='png', dpi=self.chart_dpi)
|
|
buffer.seek(0)
|
|
image_png = buffer.getvalue()
|
|
buffer.close()
|
|
|
|
# Convert to base64
|
|
image_base64 = base64.b64encode(image_png).decode('utf-8')
|
|
return image_base64
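    # Illustrative note: a caller could decode a generated visualization back to a
    # PNG for inspection; the file name below is made up.
    #
    #   png_bytes = base64.b64decode(doc["contents"][0]["data"])
    #   with open("visualization.png", "wb") as fh:
    #       fh.write(png_bytes)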
|
|
|
|
async def _generate_analysis(self, prompt: str, analysis_type: str) -> str:
|
|
"""
|
|
Generate analysis based on prompt and analysis type.
|
|
|
|
Args:
|
|
prompt: The analysis prompt
|
|
analysis_type: Type of analysis
|
|
|
|
Returns:
|
|
Generated analysis
|
|
"""
|
|
if not self.ai_service:
|
|
            logger.warning("AI service not available for analysis generation")
|
|
return f"## Data Analysis ({analysis_type})\n\nUnable to generate analysis: AI service not available."
|
|
|
|
# Create specialized prompt based on analysis type
|
|
system_prompt = self._get_analysis_system_prompt(analysis_type)
|
|
|
|
# Enhance the prompt with analysis-specific instructions
|
|
enhanced_prompt = f"""
|
|
Generate a detailed {analysis_type} analysis based on the following:
|
|
|
|
{prompt}
|
|
|
|
Your analysis should include:
|
|
1. A summary of the data
|
|
2. Key findings and insights
|
|
3. Supporting evidence and calculations
|
|
4. Clear conclusions
|
|
5. Recommendations where appropriate
|
|
|
|
Format the analysis in Markdown with proper headings, lists, and tables.
|
|
"""
|
|
|
|
try:
|
|
content = await self.ai_service.call_api([
|
|
{"role": "system", "content": system_prompt},
|
|
{"role": "user", "content": enhanced_prompt}
|
|
])
|
|
|
|
# Ensure there's a title at the top
|
|
if not content.strip().startswith("# "):
|
|
content = f"# {analysis_type.capitalize()} Analysis\n\n{content}"
|
|
|
|
return content
|
|
except Exception as e:
|
|
return f"# {analysis_type.capitalize()} Analysis\n\nError generating analysis: {str(e)}"
|
|
|
|
def _get_analysis_system_prompt(self, analysis_type: str) -> str:
|
|
"""
|
|
Get specialized system prompt for specific analysis type.
|
|
|
|
Args:
|
|
analysis_type: Type of analysis
|
|
|
|
Returns:
|
|
System prompt
|
|
"""
|
|
base_prompt = self._get_system_prompt()
|
|
|
|
# Add analysis-specific instructions
|
|
if analysis_type == "statistical":
|
|
return f"{base_prompt}\n\nFocus on statistical measures including mean, median, mode, variance, and distribution. Identify outliers and unusual data points. Present key statistics in tables where appropriate."
|
|
|
|
elif analysis_type == "trend":
|
|
return f"{base_prompt}\n\nFocus on identifying trends over time, seasonality, and patterns in the data. Look for long-term movements, cyclical patterns, and turning points. Consider rate of change and growth rates."
|
|
|
|
elif analysis_type == "comparative":
|
|
return f"{base_prompt}\n\nFocus on comparing different groups, categories, or time periods. Highlight similarities and differences. Use comparative metrics and relative measures rather than just absolute values."
|
|
|
|
elif analysis_type == "predictive":
|
|
return f"{base_prompt}\n\nFocus on extrapolating trends and patterns to make predictions about future values. Discuss confidence levels and potential factors that could influence outcomes. Be clear about assumptions."
|
|
|
|
elif analysis_type == "clustering":
|
|
return f"{base_prompt}\n\nFocus on identifying natural groupings or segments within the data. Describe the characteristics of each cluster and what distinguishes them. Consider similarities within groups and differences between groups."
|
|
|
|
else:
|
|
return base_prompt
|
|
|
|
def _get_system_prompt(self) -> str:
|
|
"""
|
|
Get specialized system prompt for analyst agent.
|
|
|
|
Returns:
|
|
System prompt
|
|
"""
|
|
return f"""
|
|
You are {self.name}, a specialized {self.type} agent focused on data analysis.
|
|
|
|
{self.description}
|
|
|
|
When analyzing data:
|
|
1. First, identify the data structure and key variables
|
|
2. Look for patterns, trends, and outliers
|
|
3. Provide statistical insights and evidence-based conclusions
|
|
4. Highlight any important findings clearly
|
|
5. Suggest visualizations that would help understand the data
|
|
|
|
For CSV data, interpret tables correctly and perform calculations accurately.
|
|
For textual data, extract key metrics and relationships.
|
|
|
|
Respond in a clear, analytical style, and format your findings in a structured report.
|
|
"""
|
|
|
|
def send_analysis_result(self, analysis_content: str, sender_id: str, receiver_id: str,
|
|
task_id: str, analysis_data: Dict[str, Any] = None,
|
|
context_id: str = None) -> AgentMessage:
|
|
"""
|
|
Send analysis results using the protocol.
|
|
|
|
Args:
|
|
analysis_content: Analysis content
|
|
sender_id: Sender ID
|
|
receiver_id: Receiver ID
|
|
task_id: Task ID
|
|
analysis_data: Additional analysis data
|
|
context_id: Context ID
|
|
|
|
Returns:
|
|
Protocol message
|
|
"""
|
|
return self.protocol.create_result_message(
|
|
result_content=analysis_content,
|
|
sender_id=sender_id,
|
|
receiver_id=receiver_id,
|
|
task_id=task_id,
|
|
output_data=analysis_data,
|
|
result_format=self.result_format,
|
|
context_id=context_id
|
|
)
|
|
|
|
def send_error_message(self, error_description: str, sender_id: str, receiver_id: str = None,
|
|
error_details: Dict[str, Any] = None, context_id: str = None) -> AgentMessage:
|
|
"""
|
|
Send error message using the protocol.
|
|
|
|
Args:
|
|
error_description: Error description
|
|
sender_id: Sender ID
|
|
receiver_id: Receiver ID
|
|
error_details: Error details
|
|
context_id: Context ID
|
|
|
|
Returns:
|
|
Protocol message
|
|
"""
|
|
return self.protocol.create_error_message(
|
|
error_description=error_description,
|
|
sender_id=sender_id,
|
|
receiver_id=receiver_id,
|
|
error_type="analysis_error",
|
|
error_details=error_details,
|
|
context_id=context_id
|
|
)
|
|
|
|
def send_document_request_message(self, document_description: str, sender_id: str, receiver_id: str,
|
|
filters: Dict[str, Any] = None, context_id: str = None) -> AgentMessage:
|
|
"""
|
|
Send document request using the protocol.
|
|
|
|
Args:
|
|
document_description: Document description
|
|
sender_id: Sender ID
|
|
receiver_id: Receiver ID
|
|
filters: Document filters
|
|
context_id: Context ID
|
|
|
|
Returns:
|
|
Protocol message
|
|
"""
|
|
return self.protocol.create_document_request_message(
|
|
document_description=document_description,
|
|
sender_id=sender_id,
|
|
receiver_id=receiver_id,
|
|
filters=filters,
|
|
context_id=context_id
|
|
)
|
|
|
|
# Singleton instance
|
|
_analyst_agent = None
|
|
|
|
def get_analyst_agent():
|
|
"""Returns a singleton instance of the data analyst agent"""
|
|
global _analyst_agent
|
|
if _analyst_agent is None:
|
|
_analyst_agent = AnalystAgent()
|
|
return _analyst_agent |
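

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): obtains the singleton and runs a
# single analysis request. The message/context keys mirror what
# process_message() reads in this module; the concrete values are made up,
# and a real deployment would attach document objects and configure
# self.ai_service before expecting an LLM-generated report.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        agent = get_analyst_agent()
        message = {
            "content": "Analyze the attached CSV data and summarize key statistics.",
            "documents": [],  # document dicts as produced by the document handler
            "workflow_id": "demo-workflow",
        }
        response = await agent.process_message(message, context={"workflow_id": "demo-workflow"})
        print(response["content"])

    asyncio.run(_demo())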