"""
|
|
Datenanalyst-Agent für die Analyse und Interpretation von Daten.
|
|
Angepasst für das refaktorisierte Core-Modul mit AgentCommunicationProtocol.
|
|
"""
|
|
|
|

import logging
import traceback
import json
import re
import uuid
import io
import base64
from typing import List, Dict, Any, Optional, Union, Tuple
from datetime import datetime

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go

from modules.agentservice_base import BaseAgent
from connectors.connector_aichat_openai import ChatService
from modules.agentservice_utils import WorkflowUtils, MessageUtils, LoggingUtils, FileUtils
from modules.agentservice_protocol import AgentMessage, AgentCommunicationProtocol

logger = logging.getLogger(__name__)


class AnalystAgent(BaseAgent):
    """Agent for data analysis and interpretation"""

    def __init__(self):
        """Initialize the data analyst agent"""
        super().__init__()
        self.id = "analyst_agent"
        self.name = "Data Analyst"
        self.type = "analyst"
        self.description = "Analyzes and interprets data"
        self.capabilities = "data_analysis,pattern_recognition,statistics,visualization,data_interpretation"
        self.result_format = "AnalysisReport"

        # Initialize AI service
        self.ai_service = None

        # Document capabilities
        self.supports_documents = True
        self.document_capabilities = ["read", "analyze", "extract"]
        self.required_context = ["data_source", "analysis_objectives"]
        self.document_handler = None

        # Initialize protocol
        self.protocol = AgentCommunicationProtocol()

        # Initialize utilities
        self.message_utils = MessageUtils()
        self.file_utils = FileUtils()

        # Set up visualization defaults
        self.plt_style = 'seaborn-v0_8-whitegrid'
        self.default_figsize = (10, 6)
        self.chart_dpi = 100
        plt.style.use(self.plt_style)

    def get_agent_info(self) -> Dict[str, Any]:
        """Get agent information for the agent registry"""
        info = super().get_agent_info()
        info.update({
            "metadata": {
                "supported_formats": ["csv", "xlsx", "json", "text"],
                "analysis_types": ["statistical", "trend", "comparative", "predictive",
                                   "clustering", "textual", "summary", "general"],
                "visualization_types": ["bar", "line", "scatter", "histogram", "box", "heatmap", "pie"]
            }
        })
        return info

    def set_document_handler(self, document_handler):
        """Set the document handler for file operations"""
        self.document_handler = document_handler
"""
|
|
Main updates to the process_message method in AnalystAgent to consider all available content.
|
|
"""
|
|
    async def process_message(self, message: Dict[str, Any], context: Dict[str, Any] = None) -> Dict[str, Any]:
        """
        Process a message and perform data analysis.

        Args:
            message: Input message
            context: Optional context

        Returns:
            Analysis response
        """
        # Extract workflow_id from context or message
        workflow_id = context.get("workflow_id") if context else message.get("workflow_id", "unknown")

        # Get or create logging utilities
        log_func = context.get("log_func") if context else None
        logging_utils = LoggingUtils(workflow_id, log_func)

        # Create a status update using the protocol
        if log_func:
            status_message = self.protocol.create_status_update_message(
                status_description="Starting data analysis",
                sender_id=self.id,
                status="in_progress",
                progress=0.0,
                context_id=workflow_id
            )
            log_func(workflow_id, status_message.content, "info", self.id, self.name)

        # Create the response structure
        response = {
            "role": "assistant",
            "content": "",
            "agent_id": self.id,
            "agent_type": self.type,
            "agent_name": self.name,
            "result_format": self.result_format,
            "workflow_id": workflow_id,
            "documents": []
        }

        try:
            # Extract the task from the message
            task = message.get("content", "")

            # Process any attached documents and extract data
            document_context = ""
            data_frames = {}

            if message.get("documents"):
                logging_utils.info("Processing documents for analysis", "execution")
                document_context, data_frames = await self._process_and_extract_data(message)

                # Update progress
                if log_func:
                    status_message = self.protocol.create_status_update_message(
                        status_description="Documents processed, performing analysis",
                        sender_id=self.id,
                        status="in_progress",
                        progress=0.4,
                        context_id=workflow_id
                    )
                    log_func(workflow_id, status_message.content, "info", self.id, self.name)

            # Check whether we have either data frames OR a substantial text task
            # to analyze. This is the key change: the task text itself counts as
            # analyzable content.
            have_analyzable_content = len(data_frames) > 0 or (task and len(task.strip()) > 10)

            if not have_analyzable_content:
                # Only warn if really no content is available
                if message.get("documents"):
                    logging_utils.warning("No processable data found in the provided documents", "execution")
                    analysis_content = "## Data Analysis Report\n\nI couldn't find any processable data in the provided documents. Please ensure you've attached CSV, Excel, or other data files in a format I can analyze."
                else:
                    logging_utils.warning("No documents or analyzable content provided for analysis", "execution")
                    analysis_content = "## Data Analysis Report\n\nNo data or sufficient text content was provided for analysis. Please provide text for analysis or attach data files for me to analyze."

                response["content"] = analysis_content
                return response

            # Determine the analysis type and perform the analysis
            analysis_type = self._determine_analysis_type(task)
            logging_utils.info(f"Performing {analysis_type} analysis", "execution")

            # Create an enhanced prompt with the document context
            enhanced_prompt = self._create_enhanced_prompt(message, document_context, context)

            # Generate visualization documents if data is available
            visualization_documents = []
            if data_frames:
                logging_utils.info(f"Generating visualizations for {len(data_frames)} data sets", "execution")
                visualization_documents = self._generate_visualizations(data_frames, analysis_type, workflow_id, task)

                # Add visualizations to the response documents
                response["documents"].extend(visualization_documents)

                # Update progress
                if log_func:
                    status_message = self.protocol.create_status_update_message(
                        status_description="Visualizations created, finalizing analysis",
                        sender_id=self.id,
                        status="in_progress",
                        progress=0.7,
                        context_id=workflow_id
                    )
                    log_func(workflow_id, status_message.content, "info", self.id, self.name)

            # Generate the analysis, including data insights if we have data frames
            analysis_content = ""
            if data_frames:
                # Extract data insights to include in the analysis
                data_insights = self._extract_data_insights(data_frames)

                # Add the insights to the prompt
                enhanced_prompt += f"\n\n=== DATA INSIGHTS ===\n{data_insights}"

                # Generate the analysis with data insights
                analysis_content = await self._generate_analysis(enhanced_prompt, analysis_type)

                # Include references to the visualization documents
                if visualization_documents:
                    viz_references = "\n\n## Visualizations\n\n"
                    viz_references += "The following visualizations have been created to help understand the data:\n\n"

                    for i, doc in enumerate(visualization_documents, 1):
                        doc_source = doc.get("source", {})
                        doc_name = doc_source.get("name", f"Visualization {i}")
                        viz_references += f"{i}. **{doc_name}** - Available as an attached document\n"

                    analysis_content += viz_references
            else:
                # No data frames, but there is text to analyze. This is the key
                # change: the text content is analyzed directly.
                logging_utils.info("No data frames available, analyzing text content", "execution")
                analysis_content = await self._generate_analysis(enhanced_prompt, analysis_type)

            # Final progress update
            if log_func:
                status_message = self.protocol.create_status_update_message(
                    status_description="Analysis completed",
                    sender_id=self.id,
                    status="completed",
                    progress=1.0,
                    context_id=workflow_id
                )
                log_func(workflow_id, status_message.content, "info", self.id, self.name)

            # Set the content in the response
            response["content"] = analysis_content

            # Finish by sending a result message to the protocol if required
            if context and context.get("require_protocol_message"):
                result_message = self.send_analysis_result(
                    analysis_content=analysis_content,
                    sender_id=self.id,
                    receiver_id=context.get("receiver_id", "workflow"),
                    task_id=context.get("task_id", f"analysis_{uuid.uuid4()}"),
                    analysis_data={
                        "analysis_type": analysis_type,
                        "visualization_count": len(visualization_documents),
                        "data_frame_count": len(data_frames)
                    },
                    context_id=workflow_id
                )
                # Just log the message creation; it does not need to be returned
                logging_utils.info(f"Created protocol result message: {result_message.id}", "execution")

            return response

        except Exception as e:
            error_msg = f"Error during data analysis: {str(e)}"
            logging_utils.error(error_msg, "error")

            # Create an error message using the protocol
            self.protocol.create_error_message(
                error_description=error_msg,
                sender_id=self.id,
                error_type="analysis",
                error_details={"traceback": traceback.format_exc()},
                context_id=workflow_id
            )

            # Set the error content in the response
            response["content"] = f"## Error during data analysis\n\n{error_msg}\n\n```\n{traceback.format_exc()}\n```"
            response["status"] = "error"

            return response
"""
|
|
Add _create_enhanced_prompt method to better handle text content in analysis.
|
|
"""
|
|
|
|
    def _create_enhanced_prompt(self, message: Dict[str, Any], document_context: str, context: Dict[str, Any] = None) -> str:
        """
        Create an enhanced prompt for analysis that integrates all available content.

        Args:
            message: The original message
            document_context: Context extracted from documents
            context: Optional additional context

        Returns:
            Enhanced prompt for analysis
        """
        # Get the original task/prompt
        task = message.get("content", "")

        # Add context information if available
        context_info = ""
        if context:
            # Add any dependency outputs from previous activities
            if "dependency_outputs" in context:
                dependency_context = context.get("dependency_outputs", {})
                for name, value in dependency_context.items():
                    if isinstance(value, dict) and "content" in value:
                        context_info += f"\n\n=== INPUT FROM {name.upper()} ===\n{value['content']}"
                    else:
                        context_info += f"\n\n=== INPUT FROM {name.upper()} ===\n{str(value)}"

            # Add expected format information
            if "expected_format" in context:
                context_info += f"\n\nExpected output format: {context.get('expected_format')}"

        # Start with the task
        enhanced_prompt = f"ANALYSIS TASK:\n{task}"

        # Add any context information
        if context_info:
            enhanced_prompt += f"\n\n{context_info}"

        # Add document context if available
        if document_context:
            enhanced_prompt += f"\n\n=== DOCUMENT CONTENT ===\n{document_context}"
        else:
            # If there is no document content, note explicitly that the text
            # content itself is being analyzed
            enhanced_prompt += "\n\nNo data files were provided. Perform analysis on the text content itself."

        # Add final instructions
        if document_context:
            enhanced_prompt += "\n\nBased on the data and documents provided, please perform a comprehensive analysis."
        else:
            enhanced_prompt += "\n\nBased on the text content provided, please perform a comprehensive analysis."

        if task:
            enhanced_prompt += f" Focus specifically on addressing: {task}"

        enhanced_prompt += "\n\nProvide insights, patterns, and conclusions in a clear, structured format."

        return enhanced_prompt
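
    # The assembled prompt has roughly this shape (illustrative sketch only):
    #
    #   ANALYSIS TASK:
    #   <task text>
    #
    #   === INPUT FROM <DEPENDENCY> ===   (one block per dependency output)
    #   ...
    #
    #   === DOCUMENT CONTENT ===          (or a note that no files were provided)
    #   ...
    #
    #   <closing instructions>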

    async def _process_and_extract_data(self, message: Dict[str, Any]) -> Tuple[str, Dict[str, pd.DataFrame]]:
        """
        Process documents and extract structured data.

        Args:
            message: Input message with documents

        Returns:
            Tuple of (document_context, data_frames_dict)
        """
        document_context = ""
        data_frames = {}

        if not message.get("documents"):
            return document_context, data_frames

        # Extract document text (this becomes our context)
        if self.document_handler:
            document_context = self.document_handler.merge_document_contents(message)
        else:
            document_context = self._extract_document_text(message)

        # Identify and process data files (CSV, Excel, etc.)
        for document in message.get("documents", []):
            source = document.get("source", {})
            filename = source.get("name", "")
            file_id = source.get("id", 0)
            content_type = source.get("content_type", "")

            # Skip if not a recognizable data file
            if not self._is_data_file(filename, content_type):
                continue

            try:
                # Try to get the file content through the document handler first
                file_content = None
                if self.document_handler:
                    file_content = self.document_handler.get_file_content_from_message(message, file_id=file_id)

                # Process based on file type
                if filename.lower().endswith('.csv'):
                    df = self._process_csv(file_content, filename)
                    if df is not None:
                        data_frames[filename] = df

                elif filename.lower().endswith(('.xlsx', '.xls')):
                    dfs = self._process_excel(file_content, filename)
                    for sheet_name, df in dfs.items():
                        data_frames[f"{filename}::{sheet_name}"] = df

                elif filename.lower().endswith('.json'):
                    df = self._process_json(file_content, filename)
                    if df is not None:
                        data_frames[filename] = df

            except Exception as e:
                logger.error(f"Error processing file {filename}: {str(e)}")

        return document_context, data_frames
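
    # Assumed shape of an incoming document entry, inferred from the key
    # accesses above (the authoritative schema lives in the document handler
    # and protocol modules):
    #
    #   {
    #       "source": {"id": 1, "name": "sales.csv", "content_type": "text/csv"},
    #       "contents": [{"type": "text", "text": "region,revenue\nEMEA,120\n..."}]
    #   }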

    def _is_data_file(self, filename: str, content_type: str) -> bool:
        """Check whether a file is a processable data file"""
        if filename.lower().endswith(('.csv', '.xlsx', '.xls', '.json')):
            return True

        if content_type in ['text/csv', 'application/vnd.ms-excel',
                            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
                            'application/json']:
            return True

        return False
    def _process_csv(self, file_content: Union[bytes, str], filename: str) -> Optional[pd.DataFrame]:
        """Process CSV file content into a pandas DataFrame"""
        if file_content is None:
            return None

        try:
            # Handle the case where file_content is already a string
            if isinstance(file_content, str):
                df = pd.read_csv(io.StringIO(file_content))
                return self._preprocess_dataframe(df)

            # Handle the case where file_content is bytes: try several encodings
            for encoding in ['utf-8', 'latin1', 'cp1252']:
                try:
                    # Use StringIO to create a file-like object
                    text_content = file_content.decode(encoding)
                    df = pd.read_csv(io.StringIO(text_content))
                    return self._preprocess_dataframe(df)
                except UnicodeDecodeError:
                    continue
                except Exception as e:
                    logger.error(f"Error processing CSV with {encoding} encoding: {str(e)}")

            # If all encodings fail, try one more time with errors='replace'
            text_content = file_content.decode('utf-8', errors='replace')
            df = pd.read_csv(io.StringIO(text_content))
            return self._preprocess_dataframe(df)

        except Exception as e:
            logger.error(f"Failed to process CSV file {filename}: {str(e)}")
            return None

    def _process_excel(self, file_content: bytes, filename: str) -> Dict[str, pd.DataFrame]:
        """Process Excel file content into pandas DataFrames, one per sheet"""
        result = {}

        if file_content is None:
            return result

        try:
            # Use BytesIO to create a file-like object
            excel_file = io.BytesIO(file_content)
            excel_data = pd.ExcelFile(excel_file)

            # Process each sheet
            for sheet_name in excel_data.sheet_names:
                df = excel_data.parse(sheet_name)

                # Basic preprocessing
                df = self._preprocess_dataframe(df)

                # Only include sheets that contain actual data
                if not df.empty:
                    result[sheet_name] = df

            return result

        except Exception as e:
            logger.error(f"Failed to process Excel file {filename}: {str(e)}")
            return result

    def _process_json(self, file_content: Union[bytes, str], filename: str) -> Optional[pd.DataFrame]:
        """Process JSON file content into a pandas DataFrame"""
        if file_content is None:
            return None

        try:
            # Decode and parse JSON (the document handler may hand us str or bytes)
            json_content = file_content.decode('utf-8') if isinstance(file_content, bytes) else file_content
            data = json.loads(json_content)

            # Handle different JSON structures
            if isinstance(data, list):
                # List of records
                df = pd.DataFrame(data)
            elif isinstance(data, dict):
                if any(isinstance(v, list) for v in data.values()):
                    # Use the first non-empty list value as the data
                    for value in data.values():
                        if isinstance(value, list) and len(value) > 0:
                            df = pd.DataFrame(value)
                            break
                    else:
                        # No suitable list found
                        return None
                else:
                    # Convert a flat dict to a single-row DataFrame
                    df = pd.DataFrame([data])
            else:
                # Unsupported structure
                return None

            # Basic preprocessing
            return self._preprocess_dataframe(df)

        except Exception as e:
            logger.error(f"Failed to process JSON file {filename}: {str(e)}")
            return None

    def _preprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Perform basic preprocessing on a DataFrame"""
        if df.empty:
            return df

        # Remove completely empty rows and columns
        df = df.dropna(how='all')
        df = df.dropna(axis=1, how='all')

        # Try to convert string columns to numeric where appropriate
        for col in df.columns:
            # Skip if already numeric
            if pd.api.types.is_numeric_dtype(df[col]):
                continue

            if df[col].dtype == 'object':
                non_na_values = df[col].dropna()
                if len(non_na_values) == 0:
                    continue

                # Convert the column only if more than 80% of the non-NA values
                # can be parsed as numbers
                numeric_count = pd.to_numeric(non_na_values, errors='coerce').notna().sum()
                if numeric_count / len(non_na_values) > 0.8:
                    df[col] = pd.to_numeric(df[col], errors='coerce')

        # Try to parse date columns
        for col in df.columns:
            # Skip if not object dtype
            if df[col].dtype != 'object':
                continue

            # Check whether the column name suggests a date
            if any(date_term in col.lower() for date_term in ['date', 'time', 'day', 'month', 'year']):
                try:
                    # Parse into a separate series first, so the original values
                    # are kept if the conversion introduces too many NaT values
                    converted = pd.to_datetime(df[col], errors='coerce')
                    # Only keep the conversion if at least 80% succeeded
                    if converted.notna().mean() >= 0.8:
                        df[col] = converted
                except Exception:
                    pass

        return df
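
    # Illustrative behavior of the heuristics above: a column read as strings
    # where 9 of 10 values parse as numbers becomes numeric (the odd value
    # turns into NaN), while a "date"-named column whose values mostly fail
    # datetime parsing keeps its original string values.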

    def _extract_document_text(self, message: Dict[str, Any]) -> str:
        """
        Extract text from documents (fallback method).

        Args:
            message: Input message with documents

        Returns:
            Extracted text
        """
        text_content = ""
        for document in message.get("documents", []):
            source = document.get("source", {})
            name = source.get("name", "unnamed")

            text_content += f"\n\n--- {name} ---\n"

            for content in document.get("contents", []):
                if content.get("type") == "text":
                    text_content += content.get("text", "")

        return text_content

    def _determine_analysis_type(self, task: str) -> str:
        """
        Determine the type of analysis based on the task.
        Enhanced to better handle text-based analysis.

        Args:
            task: The analysis task

        Returns:
            Analysis type
        """
        task_lower = task.lower()

        # Statistical analysis
        if any(term in task_lower for term in ["statistics", "statistical", "mean", "median", "variance"]):
            return "statistical"

        # Trend analysis
        elif any(term in task_lower for term in ["trend", "pattern", "time series", "historical"]):
            return "trend"

        # Comparative analysis
        elif any(term in task_lower for term in ["compare", "comparison", "versus", "vs", "difference"]):
            return "comparative"

        # Predictive analysis
        elif any(term in task_lower for term in ["predict", "forecast", "future", "projection"]):
            return "predictive"

        # Clustering or categorization
        elif any(term in task_lower for term in ["cluster", "segment", "categorize", "classify"]):
            return "clustering"

        # Text-analysis-specific terms
        elif any(term in task_lower for term in ["text", "sentiment", "topic", "semantic", "meaning", "interpretation"]):
            return "textual"

        # Summary requests
        elif any(term in task_lower for term in ["summarize", "summary", "overview", "digest"]):
            return "summary"

        # Default to general analysis
        else:
            return "general"
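
    # Illustrative keyword routing (the first matching branch wins):
    #   "Compare Q1 vs Q2 revenue"          -> "comparative"
    #   "Forecast next quarter's demand"    -> "predictive"
    #   "What is the sentiment of this?"    -> "textual"
    #   "Give me an overview of the notes"  -> "summary"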

    def _extract_data_insights(self, data_frames: Dict[str, pd.DataFrame]) -> str:
        """
        Extract basic insights from data frames.

        Args:
            data_frames: Dictionary of data frames

        Returns:
            Extracted insights as text
        """
        insights = []

        for name, df in data_frames.items():
            if df.empty:
                continue

            insight = f"Dataset: {name}\n"
            insight += f"Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
            insight += f"Columns: {', '.join(df.columns.tolist())}\n"

            # Basic statistics for numeric columns
            numeric_cols = df.select_dtypes(include=['number']).columns
            if len(numeric_cols) > 0:
                insight += "Numeric column statistics:\n"
                for col in numeric_cols[:5]:  # Limit to the first 5 columns
                    stats = df[col].describe()
                    insight += f"  {col}: min={stats['min']:.2f}, max={stats['max']:.2f}, mean={stats['mean']:.2f}, median={df[col].median():.2f}\n"

                if len(numeric_cols) > 5:
                    insight += f"  ... and {len(numeric_cols) - 5} more numeric columns\n"

            # Date range for datetime columns
            date_cols = df.select_dtypes(include=['datetime']).columns
            if len(date_cols) > 0:
                insight += "Date range:\n"
                for col in date_cols:
                    if df[col].notna().any():
                        min_date = df[col].min()
                        max_date = df[col].max()
                        insight += f"  {col}: {min_date} to {max_date}\n"

            # Categorical column value counts
            cat_cols = df.select_dtypes(include=['object', 'category']).columns
            if len(cat_cols) > 0:
                insight += "Categorical columns:\n"
                for col in cat_cols[:3]:  # Limit to the first 3 columns
                    # Get the top 3 values
                    top_values = df[col].value_counts().head(3)
                    vals_str = ", ".join([f"{val} ({count})" for val, count in top_values.items()])
                    insight += f"  {col}: {df[col].nunique()} unique values. Top values: {vals_str}\n"

                if len(cat_cols) > 3:
                    insight += f"  ... and {len(cat_cols) - 3} more categorical columns\n"

            # Missing values
            missing = df.isna().sum()
            if missing.sum() > 0:
                cols_with_missing = missing[missing > 0]
                insight += "Missing values:\n"
                for col, count in cols_with_missing.items():
                    pct = 100 * count / len(df)
                    insight += f"  {col}: {count} missing values ({pct:.1f}%)\n"

            insights.append(insight)

        return "\n\n".join(insights)
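
    # Example of the generated insight text for a small sales table
    # (illustrative values only):
    #
    #   Dataset: sales.csv
    #   Shape: 120 rows, 4 columns
    #   Columns: date, region, revenue, units
    #   Numeric column statistics:
    #     revenue: min=10.00, max=950.00, mean=312.45, median=280.00
    #   ...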

    def _generate_visualizations(self, data_frames: Dict[str, pd.DataFrame], analysis_type: str,
                                 workflow_id: str, task: str) -> List[Dict[str, Any]]:
        """
        Generate appropriate visualizations based on the data and analysis type.

        Args:
            data_frames: Dictionary of DataFrames to visualize
            analysis_type: Type of analysis being performed
            workflow_id: Workflow ID
            task: Original task description

        Returns:
            List of visualization document objects
        """
        documents = []

        for name, df in data_frames.items():
            if df.empty or df.shape[0] < 2:
                continue  # Skip empty or single-row DataFrames

            # Generate different visualizations based on the analysis type
            if analysis_type == "statistical":
                documents.extend(self._create_statistical_visualizations(df, name, workflow_id))
            elif analysis_type == "trend":
                documents.extend(self._create_trend_visualizations(df, name, workflow_id))
            elif analysis_type == "comparative":
                documents.extend(self._create_comparative_visualizations(df, name, workflow_id))
            elif analysis_type == "predictive":
                documents.extend(self._create_predictive_visualizations(df, name, workflow_id))
            elif analysis_type == "clustering":
                documents.extend(self._create_clustering_visualizations(df, name, workflow_id))
            else:  # general analysis
                documents.extend(self._create_general_visualizations(df, name, workflow_id))

        return documents

    def _build_visualization_document(self, doc_prefix: str, display_name: str) -> Dict[str, Any]:
        """
        Capture the current matplotlib figure as a base64-encoded PNG document.

        Consolidates the identical save/close/package steps that every
        _create_*_visualizations method below repeats.
        """
        img_data = self._get_figure_as_base64()
        plt.close()

        doc_id = f"{doc_prefix}_{uuid.uuid4()}"
        return {
            "id": doc_id,
            "source": {
                "type": "generated",
                "id": doc_id,
                "name": display_name,
                "content_type": "image/png",
                "size": len(img_data)
            },
            "contents": [{
                "type": "image",
                "data": img_data,
                "format": "base64"
            }]
        }

    def _create_statistical_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
        """Create statistical visualizations for a DataFrame"""
        documents = []

        # 1. Distribution/histogram plots for numeric columns
        numeric_cols = df.select_dtypes(include=['number']).columns[:5]  # Limit to the first 5
        if len(numeric_cols) > 0:
            plt.figure(figsize=(12, 8))
            for i, col in enumerate(numeric_cols, 1):
                plt.subplot(len(numeric_cols), 1, i)
                sns.histplot(df[col].dropna(), kde=True)
                plt.title(f'Distribution of {col}')
            plt.tight_layout()
            documents.append(self._build_visualization_document(
                "viz_stat_dist", f"Statistical Distributions - {name}"))

        # 2. Box plots for numeric columns
        if len(numeric_cols) > 0:
            plt.figure(figsize=(12, 8))
            sns.boxplot(data=df[numeric_cols])
            plt.title(f'Box Plots of Numeric Variables in {name}')
            plt.xticks(rotation=45)
            plt.tight_layout()
            documents.append(self._build_visualization_document(
                "viz_stat_box", f"Box Plots - {name}"))

        # 3. Correlation heatmap for numeric columns
        if len(numeric_cols) >= 2:
            plt.figure(figsize=(10, 8))
            corr = df[numeric_cols].corr()
            sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
            plt.title(f'Correlation Heatmap - {name}')
            plt.tight_layout()
            documents.append(self._build_visualization_document(
                "viz_stat_corr", f"Correlation Heatmap - {name}"))

        return documents

    def _create_trend_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
        """Create trend visualizations for a DataFrame"""
        documents = []

        # Check for date/time columns
        date_cols = df.select_dtypes(include=['datetime']).columns

        # If we have date columns, create time series plots
        if len(date_cols) > 0:
            date_col = date_cols[0]  # Use the first date column

            # Find numeric columns to plot against the date
            numeric_cols = df.select_dtypes(include=['number']).columns[:3]  # Limit to the first 3

            if len(numeric_cols) > 0:
                plt.figure(figsize=(12, 8))
                for i, col in enumerate(numeric_cols, 1):
                    plt.subplot(len(numeric_cols), 1, i)
                    plt.plot(df[date_col], df[col])
                    plt.title(f'Trend of {col} over time')
                    plt.xticks(rotation=45)
                plt.tight_layout()
                documents.append(self._build_visualization_document(
                    "viz_trend_time", f"Time Series Trends - {name}"))

        # If no date columns were found, look for a column that might represent
        # a sequence or ordering
        else:
            # Look for columns with mostly sequential numbers
            potential_sequence_cols = []
            for col in df.select_dtypes(include=['number']).columns:
                values = df[col].dropna().values
                if len(values) >= 5:
                    # Check whether the values are mostly sequential
                    diffs = np.diff(sorted(values))
                    if np.all(diffs > 0) and np.std(diffs) / np.mean(diffs) < 0.5:
                        potential_sequence_cols.append(col)

            # Use the first potential sequence column, if any
            numeric_cols = df.select_dtypes(include=['number']).columns
            if len(potential_sequence_cols) > 0 and len(numeric_cols) > 1:
                sequence_col = potential_sequence_cols[0]
                # Find other numeric columns to plot against the sequence
                plot_cols = [col for col in numeric_cols if col != sequence_col][:2]

                plt.figure(figsize=(12, 6))
                for col in plot_cols:
                    plt.plot(df[sequence_col], df[col], marker='o', label=col)
                plt.title(f'Trend by {sequence_col} - {name}')
                plt.legend()
                plt.tight_layout()
                documents.append(self._build_visualization_document(
                    "viz_trend_seq", f"Sequential Trends - {name}"))

        # Moving-average visualization if we have enough data points
        if len(df) > 10:
            numeric_cols = df.select_dtypes(include=['number']).columns[:2]  # Limit to the first 2
            if len(numeric_cols) > 0:
                plt.figure(figsize=(12, 6))

                for col in numeric_cols:
                    # Sort the data if we have a date column
                    if len(date_cols) > 0:
                        sorted_df = df.sort_values(by=date_cols[0])
                    else:
                        sorted_df = df

                    # Calculate the moving average (window size 3)
                    values = sorted_df[col].values
                    window_size = min(3, len(values) - 1)
                    if window_size > 0:
                        moving_avg = np.convolve(values, np.ones(window_size) / window_size, mode='valid')

                        # Plot the original series and its moving average
                        plt.plot(values, label=f'{col} (Original)')
                        plt.plot(np.arange(window_size - 1, len(values)), moving_avg, label=f'{col} (Moving Avg)')

                plt.title(f'Moving Average Trends - {name}')
                plt.legend()
                plt.tight_layout()
                documents.append(self._build_visualization_document(
                    "viz_trend_mavg", f"Moving Average Trends - {name}"))

        return documents

    def _create_comparative_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
        """Create comparative visualizations for a DataFrame"""
        documents = []

        # 1. Look for categorical columns to use for grouping
        cat_cols = df.select_dtypes(include=['object', 'category']).columns

        if len(cat_cols) > 0:
            # Use the first categorical column with a reasonable number of unique values
            groupby_col = None
            for col in cat_cols:
                if 2 <= df[col].nunique() <= 10:  # Reasonable number of categories
                    groupby_col = col
                    break

            if groupby_col:
                # Find numeric columns to compare across groups
                numeric_cols = df.select_dtypes(include=['number']).columns[:3]  # Limit to the first 3

                if len(numeric_cols) > 0:
                    # Bar chart comparing means
                    plt.figure(figsize=(12, 6))
                    mean_by_group = df.groupby(groupby_col)[numeric_cols].mean()
                    mean_by_group.plot(kind='bar')
                    plt.title(f'Mean Comparison by {groupby_col} - {name}')
                    plt.xticks(rotation=45)
                    plt.tight_layout()
                    documents.append(self._build_visualization_document(
                        "viz_comp_bar", f"Mean Comparison by {groupby_col} - {name}"))

                    # Box plots for comparing distributions
                    plt.figure(figsize=(12, 8))
                    for i, col in enumerate(numeric_cols, 1):
                        plt.subplot(len(numeric_cols), 1, i)
                        sns.boxplot(x=groupby_col, y=col, data=df)
                        plt.title(f'Distribution of {col} by {groupby_col}')
                        plt.xticks(rotation=45)
                    plt.tight_layout()
                    documents.append(self._build_visualization_document(
                        "viz_comp_box", f"Distribution Comparison - {name}"))

        # 2. Scatter plot comparing two numeric variables
        numeric_cols = df.select_dtypes(include=['number']).columns
        if len(numeric_cols) >= 2:
            # Use the first two numeric columns
            x_col, y_col = numeric_cols[0], numeric_cols[1]

            # Color by the first low-cardinality categorical column, if any
            # (deciding this up front avoids creating a second, leaked figure)
            color_col = None
            if len(cat_cols) > 0 and df[cat_cols[0]].nunique() <= 10:
                color_col = cat_cols[0]

            plt.figure(figsize=(10, 8))
            if color_col:
                scatter = plt.scatter(df[x_col], df[y_col],
                                      c=pd.factorize(df[color_col])[0], cmap='viridis')
                plt.title(f'Comparison of {x_col} vs {y_col} by {color_col} - {name}')
                legend1 = plt.legend(scatter.legend_elements()[0],
                                     df[color_col].unique(), title=color_col)
                plt.gca().add_artist(legend1)
            else:
                plt.scatter(df[x_col], df[y_col])
                plt.title(f'Comparison of {x_col} vs {y_col} - {name}')
            plt.xlabel(x_col)
            plt.ylabel(y_col)
            plt.tight_layout()
            documents.append(self._build_visualization_document(
                "viz_comp_scatter", f"Variable Comparison - {name}"))

        return documents

    def _create_predictive_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
        """Create predictive visualizations for a DataFrame"""
        documents = []

        # Check for date/time columns for time series prediction
        date_cols = df.select_dtypes(include=['datetime']).columns

        if len(date_cols) > 0:
            date_col = date_cols[0]  # Use the first date column

            # Sort by date
            df_sorted = df.sort_values(by=date_col)

            # Find numeric columns to predict
            numeric_cols = df.select_dtypes(include=['number']).columns[:2]  # Limit to the first 2

            if len(numeric_cols) > 0:
                plt.figure(figsize=(12, 8))

                for i, col in enumerate(numeric_cols, 1):
                    plt.subplot(len(numeric_cols), 1, i)

                    # Get values and dates (as pandas Timestamps, so that
                    # subtraction yields Timedelta objects with total_seconds())
                    values = df_sorted[col].values
                    dates = pd.to_datetime(df_sorted[date_col]).reset_index(drop=True)

                    # A minimum number of points is needed for a meaningful prediction
                    if len(values) >= 5:
                        # Use basic linear regression for the prediction; convert
                        # the dates to normalized numeric values for the fit
                        date_nums = np.array([(d - dates[0]).total_seconds() for d in dates])
                        max_seconds = np.max(date_nums)
                        date_nums = date_nums / max_seconds  # Normalize

                        # Remove NaNs
                        mask = ~np.isnan(values)
                        if np.sum(mask) >= 3:  # Need at least 3 points
                            x = date_nums[mask].reshape(-1, 1)
                            y = values[mask]

                            # Fit linear regression
                            from sklearn.linear_model import LinearRegression
                            model = LinearRegression()
                            model.fit(x, y)

                            # Extend the range by 20% for the prediction
                            x_extended = np.linspace(0, 1.2, 100).reshape(-1, 1)
                            y_extended = model.predict(x_extended)

                            # Convert x_extended back to dates for plotting
                            future_seconds = x_extended.flatten() * max_seconds
                            future_dates = [dates[0] + pd.Timedelta(seconds=s) for s in future_seconds]

                            # Plot
                            plt.plot(dates, values, 'o-', label='Actual')
                            plt.plot(future_dates, y_extended, '--', label='Predicted')
                            plt.axvline(x=dates.iloc[-1], color='r', linestyle=':', label='Current')

                            plt.title(f'Prediction for {col}')
                            plt.xticks(rotation=45)
                            plt.legend()

                plt.tight_layout()
                documents.append(self._build_visualization_document(
                    "viz_pred_time", f"Time Series Prediction - {name}"))

        # Regression prediction (feature vs target)
        numeric_cols = df.select_dtypes(include=['number']).columns
        if len(numeric_cols) >= 2:
            # Use the first two numeric columns as feature and target
            x_col, y_col = numeric_cols[0], numeric_cols[1]

            # Remove NaNs
            df_clean = df[[x_col, y_col]].dropna()

            if len(df_clean) >= 5:  # Need a minimum number of points for regression
                plt.figure(figsize=(10, 8))

                x = df_clean[x_col].values.reshape(-1, 1)
                y = df_clean[y_col].values

                # Fit linear regression
                from sklearn.linear_model import LinearRegression
                model = LinearRegression()
                model.fit(x, y)

                # Generate predictions over a slightly extended range
                x_range = np.linspace(df_clean[x_col].min(), df_clean[x_col].max() * 1.1, 100).reshape(-1, 1)
                y_pred = model.predict(x_range)

                # Plot
                plt.scatter(df_clean[x_col], df_clean[y_col], label='Data Points')
                plt.plot(x_range, y_pred, 'r--', label=f'Predicted {y_col}')
                plt.title(f'Regression Prediction: {y_col} based on {x_col} - {name}')
                plt.xlabel(x_col)
                plt.ylabel(y_col)
                plt.legend()

                # Annotate with the regression equation
                slope = model.coef_[0]
                intercept = model.intercept_
                plt.text(0.05, 0.95, f'{y_col} = {slope:.2f} * {x_col} + {intercept:.2f}',
                         transform=plt.gca().transAxes, fontsize=10,
                         verticalalignment='top')

                plt.tight_layout()
                documents.append(self._build_visualization_document(
                    "viz_pred_reg", f"Regression Prediction - {name}"))

        return documents

    def _create_clustering_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
        """Create clustering visualizations for a DataFrame"""
        documents = []

        # Numeric columns are needed for clustering
        numeric_cols = df.select_dtypes(include=['number']).columns
        if len(numeric_cols) >= 2:
            from sklearn.cluster import KMeans
            from sklearn.decomposition import PCA
            from sklearn.preprocessing import StandardScaler

            # Select two numeric columns for the 2D visualization
            cols = numeric_cols[:2]

            # Remove NaNs
            df_clean = df[cols].dropna()

            if len(df_clean) >= 5:  # Need a minimum number of points for clustering
                # Normalize the data
                scaler = StandardScaler()
                data_scaled = scaler.fit_transform(df_clean)

                # Apply K-means clustering; choose 2-5 clusters based on data size
                n_clusters = min(max(2, len(df_clean) // 10), 5)
                kmeans = KMeans(n_clusters=n_clusters, random_state=42)
                clusters = kmeans.fit_predict(data_scaled)

                # Add the cluster labels to the DataFrame
                df_clean['Cluster'] = clusters

                # Create a scatter plot with the clusters
                plt.figure(figsize=(10, 8))
                scatter = plt.scatter(df_clean[cols[0]], df_clean[cols[1]], c=df_clean['Cluster'], cmap='viridis')

                # Plot the centroids
                centroids = scaler.inverse_transform(kmeans.cluster_centers_)
                plt.scatter(centroids[:, 0], centroids[:, 1], marker='X', s=200, c='red', label='Centroids')

                plt.title(f'K-means Clustering ({n_clusters} clusters) - {name}')
                plt.xlabel(cols[0])
                plt.ylabel(cols[1])
                cluster_legend = plt.legend(*scatter.legend_elements(), title="Clusters")
                plt.gca().add_artist(cluster_legend)
                plt.legend(loc='lower right')  # Shows the 'Centroids' entry
                plt.tight_layout()
                documents.append(self._build_visualization_document(
                    "viz_clust_kmeans", f"K-means Clustering - {name}"))

                # With more than 2 numeric columns, also create a PCA visualization
                if len(numeric_cols) > 2:
                    # Use up to 5 columns for the PCA
                    pca_cols = numeric_cols[:min(len(numeric_cols), 5)]

                    # Remove NaNs
                    df_pca = df[pca_cols].dropna()

                    if len(df_pca) >= 5:
                        # Normalize the data
                        pca_data = StandardScaler().fit_transform(df_pca)

                        # Apply PCA to reduce to 2 dimensions
                        pca = PCA(n_components=2)
                        principal_components = pca.fit_transform(pca_data)

                        # Create a DataFrame with the principal components
                        pca_df = pd.DataFrame(data=principal_components, columns=['PC1', 'PC2'])

                        # Apply clustering to the PCA results
                        clusters = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(pca_df)
                        pca_df['Cluster'] = clusters

                        # Create a scatter plot
                        plt.figure(figsize=(10, 8))
                        scatter = plt.scatter(pca_df['PC1'], pca_df['PC2'], c=pca_df['Cluster'], cmap='viridis')
                        plt.title(f'PCA Clustering ({n_clusters} clusters) - {name}')
                        plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
                        plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
                        plt.legend(*scatter.legend_elements(), title="Clusters")
                        plt.tight_layout()
                        documents.append(self._build_visualization_document(
                            "viz_clust_pca", f"PCA Clustering - {name}"))

        return documents

    def _create_general_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
        """Create general-purpose visualizations for a DataFrame"""
        documents = []

        # 1. Data overview: numeric summary
        numeric_cols = df.select_dtypes(include=['number']).columns
        if len(numeric_cols) > 0:
            # Create a bar chart of means for the numeric columns
            plt.figure(figsize=(12, 6))
            means = df[numeric_cols].mean().sort_values()
            means.plot(kind='bar')
            plt.title(f'Mean Values of Numeric Variables - {name}')
            plt.xticks(rotation=45)
            plt.tight_layout()
            documents.append(self._build_visualization_document(
                "viz_gen_means", f"Numeric Variables Summary - {name}"))

        # 2. Categorical data overview
        cat_cols = df.select_dtypes(include=['object', 'category']).columns
        if len(cat_cols) > 0:
            # Use the first categorical column with reasonable cardinality
            for col in cat_cols:
                if df[col].nunique() <= 10:  # Reasonable number of categories
                    plt.figure(figsize=(10, 6))
                    value_counts = df[col].value_counts().sort_values(ascending=False)
                    value_counts.plot(kind='bar')
                    plt.title(f'Distribution of {col} - {name}')
                    plt.xticks(rotation=45)
                    plt.tight_layout()
                    documents.append(self._build_visualization_document(
                        "viz_gen_cat", f"Categorical Distribution - {name}"))
                    break  # Only use the first suitable column

        # 3. Correlation matrix if there are multiple numeric columns
        if len(numeric_cols) >= 2:
            plt.figure(figsize=(10, 8))
            corr = df[numeric_cols].corr()
            sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
            plt.title(f'Correlation Matrix - {name}')
            plt.tight_layout()
            documents.append(self._build_visualization_document(
                "viz_gen_corr", f"Correlation Matrix - {name}"))

        # 4. If there are date columns, show a time-based visualization
        date_cols = df.select_dtypes(include=['datetime']).columns
        if len(date_cols) > 0 and len(numeric_cols) > 0:
            date_col = date_cols[0]  # Use the first date column
            num_col = numeric_cols[0]  # Use the first numeric column

            plt.figure(figsize=(12, 6))
            plt.plot(df[date_col], df[num_col], marker='o')
            plt.title(f'{num_col} over Time - {name}')
            plt.xticks(rotation=45)
            plt.tight_layout()
            documents.append(self._build_visualization_document(
                "viz_gen_time", f"Time Series Overview - {name}"))

        return documents

    def _get_figure_as_base64(self) -> str:
        """Convert the current matplotlib figure to a base64 string"""
        buffer = io.BytesIO()
        plt.savefig(buffer, format='png', dpi=self.chart_dpi)
        buffer.seek(0)
        image_png = buffer.getvalue()
        buffer.close()

        # Convert to base64
        return base64.b64encode(image_png).decode('utf-8')

    # NOTE: _generate_analysis was enhanced to handle text-only analysis in
    # addition to data-based analysis.
    async def _generate_analysis(self, prompt: str, analysis_type: str) -> str:
        """
        Generate an analysis based on the prompt and analysis type.
        Enhanced to handle text-only analysis.

        Args:
            prompt: The analysis prompt
            analysis_type: Type of analysis

        Returns:
            Generated analysis
        """
        if not self.ai_service:
            logger.warning("AI service not available for analysis generation")
            return f"## Data Analysis ({analysis_type})\n\nUnable to generate analysis: AI service not available."

        # Create a specialized system prompt based on the analysis type
        system_prompt = self._get_analysis_system_prompt(analysis_type)

        # Determine whether this is a data-based or text-based analysis
        is_data_analysis = "DATA INSIGHTS" in prompt

        # Enhance the prompt with analysis-specific instructions
        if is_data_analysis:
            enhanced_prompt = f"""
Generate a detailed {analysis_type} analysis based on the following data:

{prompt}

Your analysis should include:
1. A summary of the data
2. Key findings and insights
3. Supporting evidence and calculations
4. Clear conclusions
5. Recommendations where appropriate

Format the analysis in Markdown with proper headings, lists, and tables.
"""
        else:
            # Text-based analysis instructions
            enhanced_prompt = f"""
Generate a detailed {analysis_type} analysis of the following text content:

{prompt}

Your analysis should include:
1. A summary of the main themes and topics
2. Key insights and observations
3. Analysis of structure, patterns, and relationships
4. Clear conclusions and interpretations
5. Recommendations or implications where appropriate

Format the analysis in Markdown with proper headings, lists, and tables.
"""

        try:
            content = await self.ai_service.call_api([
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": enhanced_prompt}
            ])

            # Ensure there is a title at the top
            if not content.strip().startswith("# "):
                content = f"# {analysis_type.capitalize()} Analysis\n\n{content}"

            return content
        except Exception as e:
            return f"# {analysis_type.capitalize()} Analysis\n\nError generating analysis: {str(e)}"

    def _get_analysis_system_prompt(self, analysis_type: str) -> str:
        """
        Get a specialized system prompt for a specific analysis type.
        Enhanced with text analysis capabilities.

        Args:
            analysis_type: Type of analysis

        Returns:
            System prompt
        """
        base_prompt = self._get_system_prompt()

        # Add analysis-specific instructions
        if analysis_type == "statistical":
            return f"{base_prompt}\n\nFocus on statistical measures including mean, median, mode, variance, and distribution. Identify outliers and unusual data points. Present key statistics in tables where appropriate."

        elif analysis_type == "trend":
            return f"{base_prompt}\n\nFocus on identifying trends over time, seasonality, and patterns in the data. Look for long-term movements, cyclical patterns, and turning points. Consider rate of change and growth rates."

        elif analysis_type == "comparative":
            return f"{base_prompt}\n\nFocus on comparing different groups, categories, or time periods. Highlight similarities and differences. Use comparative metrics and relative measures rather than just absolute values."

        elif analysis_type == "predictive":
            return f"{base_prompt}\n\nFocus on extrapolating trends and patterns to make predictions about future values. Discuss confidence levels and potential factors that could influence outcomes. Be clear about assumptions."

        elif analysis_type == "clustering":
            return f"{base_prompt}\n\nFocus on identifying natural groupings or segments within the data. Describe the characteristics of each cluster and what distinguishes them. Consider similarities within groups and differences between groups."

        elif analysis_type == "textual":
            return f"{base_prompt}\n\nFocus on analyzing the text content provided. Identify key themes, topics, and concepts. Analyze sentiment, tone, and perspective. Extract important relationships, arguments, or logical structures. Provide insights into the meaning and implications of the text."

        elif analysis_type == "summary":
            return f"{base_prompt}\n\nFocus on providing a concise overview of the provided content. Identify the main points, key arguments, and essential information. Distill complex information into clear, digestible insights. Maintain objectivity while highlighting the most important elements."

        else:
            return base_prompt

    def _get_system_prompt(self) -> str:
        """
        Get the specialized system prompt for the analyst agent.
        Enhanced to handle text analysis better.

        Returns:
            System prompt
        """
        return f"""
You are {self.name}, a specialized {self.type} agent focused on data and text analysis.

{self.description}

When analyzing data:
1. First, identify the data structure and key variables
2. Look for patterns, trends, and outliers
3. Provide statistical insights and evidence-based conclusions
4. Highlight any important findings clearly
5. Suggest visualizations that would help understand the data

When analyzing text content:
1. Identify key themes, concepts, and topics
2. Extract important patterns and relationships
3. Provide insights into the meaning and implications of the text
4. Identify sentiment, tone, and perspective where relevant
5. Organize findings in a logical, structured way

For CSV data, interpret tables correctly and perform calculations accurately.
For textual data, extract key metrics, themes, and relationships.

Respond in a clear, analytical style, and format your findings in a structured report.
"""

    def send_analysis_result(self, analysis_content: str, sender_id: str, receiver_id: str,
                             task_id: str, analysis_data: Dict[str, Any] = None,
                             context_id: str = None) -> AgentMessage:
        """
        Send analysis results using the protocol.

        Args:
            analysis_content: Analysis content
            sender_id: Sender ID
            receiver_id: Receiver ID
            task_id: Task ID
            analysis_data: Additional analysis data
            context_id: Context ID

        Returns:
            Protocol message
        """
        return self.protocol.create_result_message(
            result_content=analysis_content,
            sender_id=sender_id,
            receiver_id=receiver_id,
            task_id=task_id,
            output_data=analysis_data,
            result_format=self.result_format,
            context_id=context_id
        )

    def send_error_message(self, error_description: str, sender_id: str, receiver_id: str = None,
                           error_details: Dict[str, Any] = None, context_id: str = None) -> AgentMessage:
        """
        Send an error message using the protocol.

        Args:
            error_description: Error description
            sender_id: Sender ID
            receiver_id: Receiver ID
            error_details: Error details
            context_id: Context ID

        Returns:
            Protocol message
        """
        return self.protocol.create_error_message(
            error_description=error_description,
            sender_id=sender_id,
            receiver_id=receiver_id,
            error_type="analysis_error",
            error_details=error_details,
            context_id=context_id
        )

    def send_document_request_message(self, document_description: str, sender_id: str, receiver_id: str,
                                      filters: Dict[str, Any] = None, context_id: str = None) -> AgentMessage:
        """
        Send a document request using the protocol.

        Args:
            document_description: Document description
            sender_id: Sender ID
            receiver_id: Receiver ID
            filters: Document filters
            context_id: Context ID

        Returns:
            Protocol message
        """
        return self.protocol.create_document_request_message(
            document_description=document_description,
            sender_id=sender_id,
            receiver_id=receiver_id,
            filters=filters,
            context_id=context_id
        )


# Singleton instance
_analyst_agent = None


def get_analyst_agent():
    """Return the singleton instance of the data analyst agent"""
    global _analyst_agent
    if _analyst_agent is None:
        _analyst_agent = AnalystAgent()
    return _analyst_agent
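

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): exercises the text-analysis
    # path with no documents attached. With no AI service configured,
    # _generate_analysis returns its fallback message instead of a real report.
    import asyncio

    async def _demo():
        agent = get_analyst_agent()
        message = {
            "content": "Summarize the key trends described in our quarterly notes.",
            "workflow_id": "demo-workflow",
        }
        response = await agent.process_message(message)
        print(response["content"])

    asyncio.run(_demo())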