752 lines
No EOL
30 KiB
Python
752 lines
No EOL
30 KiB
Python
"""
|
|
Data analyst agent for analysis and interpretation of data.
|
|
Optimized for the new task-based processing.
|
|
"""
|
|
|
|
import logging
|
|
import json
|
|
import re
|
|
import uuid
|
|
import io
|
|
import base64
|
|
from typing import Dict, Any, List, Optional
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
|
|
from modules.chat_registry import AgentBase
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class AgentAnalyst(AgentBase):
    """Agent for analysis and interpretation of data"""

    def __init__(self):
        """Initialize the data analysis agent"""
        super().__init__()
        # Registry identity and human-readable summary advertised to the
        # task router / chat registry.
        self.name = "analyst"
        self.description = "Analyzes and interprets data using statistical methods and visualizations"
        # Capability tags used for task routing.
        self.capabilities = [
            "data_analysis",
            "pattern_recognition",
            "statistics",
            "visualization",
            "data_interpretation"
        ]

        # Visualization settings
        self.plt_style = 'seaborn-v0_8-whitegrid'  # matplotlib style-sheet name
        self.default_figsize = (10, 6)             # figure size in inches (w, h)
        self.chart_dpi = 100                       # raster export resolution
        # NOTE(review): this mutates global matplotlib state for the whole
        # process, not just this instance — every figure created afterwards
        # uses this style.
        plt.style.use(self.plt_style)
|
|
|
|
    def set_dependencies(self, ai_service=None):
        """Set external dependencies for the agent.

        Args:
            ai_service: Client used for LLM calls (must expose an async
                ``call_api(messages)``). When None, the agent reports itself
                as unconfigured instead of processing tasks.
        """
        self.ai_service = ai_service
|
|
|
|
async def process_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""
|
|
Process a standardized task structure and perform data analysis.
|
|
|
|
Args:
|
|
task: A dictionary containing:
|
|
- task_id: Unique ID for this task
|
|
- prompt: The main instruction for the agent
|
|
- input_documents: List of documents to process
|
|
- output_specifications: List of required output documents
|
|
- context: Additional contextual information
|
|
|
|
Returns:
|
|
A dictionary containing:
|
|
- feedback: Text response explaining the analysis results
|
|
- documents: List of created document objects
|
|
"""
|
|
try:
|
|
# Extract relevant task information
|
|
prompt = task.get("prompt", "")
|
|
input_documents = task.get("input_documents", [])
|
|
output_specs = task.get("output_specifications", [])
|
|
|
|
# Check if AI service is available
|
|
if not self.ai_service:
|
|
logger.error("No AI service configured for the Analyst agent")
|
|
return {
|
|
"feedback": "The Analyst agent is not properly configured.",
|
|
"documents": []
|
|
}
|
|
|
|
# Extract data from input documents
|
|
data_frames, document_context = self._extract_data_from_documents(input_documents)
|
|
|
|
# Check if we have analyzable content
|
|
have_analyzable_content = len(data_frames) > 0 or (prompt and len(prompt.strip()) > 10)
|
|
|
|
if not have_analyzable_content:
|
|
# Warning if no analyzable content available
|
|
logger.warning("No analyzable content found")
|
|
feedback = "I couldn't find any processable data in the provided documents."
|
|
return {
|
|
"feedback": feedback,
|
|
"documents": []
|
|
}
|
|
|
|
# Determine analysis type
|
|
analysis_type = self._determine_analysis_type(prompt)
|
|
logger.info(f"Performing {analysis_type} analysis")
|
|
|
|
# Store generated documents
|
|
generated_documents = []
|
|
|
|
# Extract data insights if DataFrames are available
|
|
data_insights = ""
|
|
if data_frames:
|
|
data_insights = self._extract_data_insights(data_frames)
|
|
logger.info(f"Extracted insights from {len(data_frames)} datasets")
|
|
|
|
# Generate an appropriate document for each requested output
|
|
for spec in output_specs:
|
|
output_label = spec.get("label", "")
|
|
output_description = spec.get("description", "")
|
|
|
|
# Determine format based on file extension
|
|
format_type = self._determine_format_type(output_label)
|
|
|
|
# Special handling for visualizations if required
|
|
if "chart" in output_label.lower() or "plot" in output_label.lower() or "visualization" in output_label.lower() or format_type in ["png", "jpg", "svg"]:
|
|
# Generate visualization document if data available
|
|
if data_frames:
|
|
viz_document = self._generate_visualization_document(data_frames, analysis_type, prompt, output_label)
|
|
generated_documents.append(viz_document)
|
|
else:
|
|
# Fallback if no data
|
|
generated_documents.append({
|
|
"label": output_label,
|
|
"content": "No data available for visualization."
|
|
})
|
|
else:
|
|
# Create text-based analysis
|
|
content = await self._generate_analysis_document(
|
|
prompt,
|
|
document_context,
|
|
data_insights,
|
|
analysis_type,
|
|
format_type,
|
|
output_label,
|
|
output_description
|
|
)
|
|
|
|
generated_documents.append({
|
|
"label": output_label,
|
|
"content": content
|
|
})
|
|
|
|
# If no specific outputs requested, create standard documents
|
|
if not output_specs:
|
|
# Standard analysis
|
|
analysis_content = await self._generate_analysis_document(
|
|
prompt,
|
|
document_context,
|
|
data_insights,
|
|
analysis_type,
|
|
"markdown",
|
|
"analysis_report.md",
|
|
"Analysis report"
|
|
)
|
|
|
|
generated_documents.append({
|
|
"label": "analysis_report.md",
|
|
"content": analysis_content
|
|
})
|
|
|
|
# Add visualization if data available
|
|
if data_frames:
|
|
viz_document = self._generate_visualization_document(data_frames, analysis_type, prompt, "data_visualization.png")
|
|
generated_documents.append(viz_document)
|
|
|
|
# Create feedback
|
|
if data_frames:
|
|
feedback = f"I analyzed {len(data_frames)} datasets and created {len(generated_documents)} documents with the results."
|
|
else:
|
|
feedback = f"I performed a text analysis and created {len(generated_documents)} documents with the results."
|
|
|
|
return {
|
|
"feedback": feedback,
|
|
"documents": generated_documents
|
|
}
|
|
|
|
except Exception as e:
|
|
error_msg = f"Error during data analysis: {str(e)}"
|
|
logger.error(error_msg)
|
|
return {
|
|
"feedback": f"An error occurred during data analysis: {str(e)}",
|
|
"documents": []
|
|
}
|
|
|
|
    def _extract_data_from_documents(self, documents: List[Dict[str, Any]]) -> tuple:
        """
        Extract data from input documents.

        Builds two things in one pass: a dictionary of DataFrames parsed from
        CSV/JSON documents, and a concatenated plain-text context string of
        all text contents (with a header per document).

        Args:
            documents: List of input documents

        Returns:
            Tuple of (Dictionary of DataFrames, Document context text)
        """
        data_frames = {}
        document_context = ""

        for doc in documents:
            doc_name = doc.get("name", "unnamed")
            # Every document gets a header in the context, even when it
            # contributes no text content.
            document_context += f"\n\n--- {doc_name} ---\n"

            for content in doc.get("contents", []):
                # Extract text content and add to context
                if content.get("metadata", {}).get("is_text", False):
                    document_context += content.get("data", "")

                    # Try to parse CSV, JSON, or other data files from text
                    if doc_name.lower().endswith('.csv'):
                        try:
                            df = pd.read_csv(io.StringIO(content.get("data", "")))
                            df = self._preprocess_dataframe(df)
                            data_frames[doc_name] = df
                            logger.info(f"Extracted CSV data from {doc_name}: {df.shape}")
                        except Exception as e:
                            # Best effort: a malformed CSV still contributed
                            # its raw text to the context above.
                            logger.warning(f"Error parsing CSV {doc_name}: {str(e)}")

                    elif doc_name.lower().endswith('.json'):
                        try:
                            json_data = json.loads(content.get("data", ""))
                            if isinstance(json_data, list):
                                # Top-level list -> one row per element.
                                df = pd.DataFrame(json_data)
                            elif isinstance(json_data, dict):
                                # Convert nested JSON to DataFrame
                                if any(isinstance(v, list) for v in json_data.values()):
                                    # If lists present, use the first non-empty one.
                                    for key, value in json_data.items():
                                        if isinstance(value, list) and len(value) > 0:
                                            df = pd.DataFrame(value)
                                            break
                                    else:
                                        # No non-empty list found (for/else):
                                        # skip this content item entirely.
                                        continue
                                else:
                                    # Flat dict -> single-row DataFrame.
                                    df = pd.DataFrame([json_data])
                            else:
                                # Scalar JSON (string/number/bool): nothing tabular.
                                continue

                            df = self._preprocess_dataframe(df)
                            data_frames[doc_name] = df
                            logger.info(f"Extracted JSON data from {doc_name}: {df.shape}")
                        except Exception as e:
                            logger.warning(f"Error parsing JSON {doc_name}: {str(e)}")

        return data_frames, document_context
|
|
|
|
def _determine_format_type(self, output_label: str) -> str:
|
|
"""
|
|
Determine the format type based on the filename.
|
|
|
|
Args:
|
|
output_label: Output filename
|
|
|
|
Returns:
|
|
Format type (markdown, html, text, png, etc.)
|
|
"""
|
|
output_label_lower = output_label.lower()
|
|
|
|
if output_label_lower.endswith('.md'):
|
|
return "markdown"
|
|
elif output_label_lower.endswith('.html'):
|
|
return "html"
|
|
elif output_label_lower.endswith('.txt'):
|
|
return "text"
|
|
elif output_label_lower.endswith('.json'):
|
|
return "json"
|
|
elif output_label_lower.endswith('.csv'):
|
|
return "csv"
|
|
elif output_label_lower.endswith('.png'):
|
|
return "png"
|
|
elif output_label_lower.endswith('.jpg') or output_label_lower.endswith('.jpeg'):
|
|
return "jpg"
|
|
elif output_label_lower.endswith('.svg'):
|
|
return "svg"
|
|
else:
|
|
# Default to markdown
|
|
return "markdown"
|
|
|
|
def _preprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
"""Perform basic preprocessing for a DataFrame"""
|
|
if df.empty:
|
|
return df
|
|
|
|
# Remove completely empty rows and columns
|
|
df = df.dropna(how='all')
|
|
df = df.dropna(axis=1, how='all')
|
|
|
|
# String conversion to numeric values where appropriate
|
|
for col in df.columns:
|
|
# Skip if already numeric
|
|
if pd.api.types.is_numeric_dtype(df[col]):
|
|
continue
|
|
|
|
# Skip if predominantly non-numeric strings
|
|
if df[col].dtype == 'object':
|
|
# Check if more than 80% of non-NA values could be numeric
|
|
non_na_values = df[col].dropna()
|
|
if len(non_na_values) == 0:
|
|
continue
|
|
|
|
# Attempt conversion to numeric values
|
|
numeric_count = pd.to_numeric(non_na_values, errors='coerce').notna().sum()
|
|
if numeric_count / len(non_na_values) > 0.8:
|
|
# More than 80% can be converted to numeric values
|
|
df[col] = pd.to_numeric(df[col], errors='coerce')
|
|
|
|
return df
|
|
|
|
def _determine_analysis_type(self, task: str) -> str:
|
|
"""
|
|
Determine the analysis type based on the task.
|
|
|
|
Args:
|
|
task: The analysis task
|
|
|
|
Returns:
|
|
Analysis type
|
|
"""
|
|
# Using universal patterns rather than language-specific keywords
|
|
task_lower = task.lower()
|
|
|
|
# Check for statistical analysis
|
|
if "statistical" in task_lower or "stats" in task_lower:
|
|
return "statistical"
|
|
|
|
# Check for trend analysis
|
|
elif "trend" in task_lower or "time series" in task_lower:
|
|
return "trend"
|
|
|
|
# Check for comparative analysis
|
|
elif "compare" in task_lower or "comparison" in task_lower or "vs" in task_lower:
|
|
return "comparative"
|
|
|
|
# Check for predictive analysis
|
|
elif "predict" in task_lower or "forecast" in task_lower:
|
|
return "predictive"
|
|
|
|
# Check for clustering or categorization
|
|
elif "cluster" in task_lower or "segment" in task_lower or "classify" in task_lower:
|
|
return "clustering"
|
|
|
|
# Default: general analysis
|
|
else:
|
|
return "general"
|
|
|
|
def _extract_data_insights(self, data_frames: Dict[str, pd.DataFrame]) -> str:
|
|
"""
|
|
Extract basic insights from DataFrames.
|
|
|
|
Args:
|
|
data_frames: Dictionary of DataFrames
|
|
|
|
Returns:
|
|
Extracted insights as text
|
|
"""
|
|
insights = []
|
|
|
|
for name, df in data_frames.items():
|
|
if df.empty:
|
|
continue
|
|
|
|
insight = f"Dataset: {name}\n"
|
|
insight += f"Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
|
insight += f"Columns: {', '.join(df.columns.tolist())}\n"
|
|
|
|
# Basic statistics for numeric columns
|
|
numeric_cols = df.select_dtypes(include=['number']).columns
|
|
if len(numeric_cols) > 0:
|
|
insight += "Statistics for numeric columns:\n"
|
|
for col in numeric_cols[:5]: # Limit to first 5 columns
|
|
stats = df[col].describe()
|
|
insight += f" {col}: min={stats['min']:.2f}, max={stats['max']:.2f}, mean={stats['mean']:.2f}, median={df[col].median():.2f}\n"
|
|
|
|
# Categorical column values
|
|
cat_cols = df.select_dtypes(include=['object', 'category']).columns
|
|
if len(cat_cols) > 0:
|
|
insight += "Categorical columns:\n"
|
|
for col in cat_cols[:3]: # Limit to first 3 columns
|
|
# Get top 3 values
|
|
top_values = df[col].value_counts().head(3)
|
|
vals_str = ", ".join([f"{val} ({count})" for val, count in top_values.items()])
|
|
insight += f" {col}: {df[col].nunique()} unique values. Most common values: {vals_str}\n"
|
|
|
|
insights.append(insight)
|
|
|
|
return "\n\n".join(insights)
|
|
|
|
def _generate_visualization_document(self, data_frames: Dict[str, pd.DataFrame],
|
|
analysis_type: str, prompt: str,
|
|
output_label: str) -> Dict[str, Any]:
|
|
"""
|
|
Generate a visualization document based on the data and analysis type.
|
|
|
|
Args:
|
|
data_frames: Dictionary of DataFrames
|
|
analysis_type: Analysis type
|
|
prompt: Original task description
|
|
output_label: Output filename
|
|
|
|
Returns:
|
|
Visualization document
|
|
"""
|
|
# Determine format from filename
|
|
format_type = output_label.split('.')[-1].lower() if '.' in output_label else 'png'
|
|
|
|
# Set default format if unknown
|
|
if format_type not in ['png', 'jpg', 'jpeg', 'svg']:
|
|
format_type = 'png'
|
|
|
|
# Use first DataFrame for visualization
|
|
if not data_frames:
|
|
return {
|
|
"label": output_label,
|
|
"content": "No data available for visualization."
|
|
}
|
|
|
|
# Get name and DataFrame of first dataset
|
|
name, df = next(iter(data_frames.items()))
|
|
|
|
# Create different visualization types based on analysis type and data
|
|
plt.figure(figsize=self.default_figsize)
|
|
|
|
if analysis_type == "statistical":
|
|
# Statistical visualization
|
|
self._create_statistical_visualization(df, name)
|
|
elif analysis_type == "trend":
|
|
# Trend visualization
|
|
self._create_trend_visualization(df, name)
|
|
elif analysis_type == "comparative":
|
|
# Comparative visualization
|
|
self._create_comparative_visualization(df, name)
|
|
elif analysis_type == "predictive":
|
|
# Predictive visualization (simple example)
|
|
self._create_predictive_visualization(df, name)
|
|
elif analysis_type == "clustering":
|
|
# Clustering visualization
|
|
self._create_clustering_visualization(df, name)
|
|
else:
|
|
# General visualization
|
|
self._create_general_visualization(df, name)
|
|
|
|
# Save figure as Base64 string
|
|
img_data = self._get_figure_as_base64(format_type)
|
|
plt.close()
|
|
|
|
# Prepare content for document based on format
|
|
if format_type in ['png', 'jpg', 'jpeg']:
|
|
content_str = img_data
|
|
elif format_type == 'svg':
|
|
# SVG content as text
|
|
buffer = io.StringIO()
|
|
plt.savefig(buffer, format='svg')
|
|
content_str = buffer.getvalue()
|
|
buffer.close()
|
|
else:
|
|
# Fallback to PNG
|
|
content_str = img_data
|
|
|
|
return {
|
|
"label": output_label,
|
|
"content": content_str
|
|
}
|
|
|
|
def _create_statistical_visualization(self, df: pd.DataFrame, name: str):
|
|
"""Create a statistical visualization for a DataFrame"""
|
|
# Choose numeric columns for display
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:4] # Limit to first 4
|
|
|
|
if len(numeric_cols) == 0:
|
|
plt.text(0.5, 0.5, "No numeric data found for statistical visualization",
|
|
ha='center', va='center', fontsize=12)
|
|
return
|
|
|
|
# Visualize distribution of first numeric column
|
|
main_col = numeric_cols[0]
|
|
|
|
# Create histogram with KDE
|
|
sns.histplot(df[main_col].dropna(), kde=True)
|
|
plt.title(f'Distribution of {main_col} - {name}')
|
|
plt.xlabel(main_col)
|
|
plt.ylabel('Frequency')
|
|
plt.tight_layout()
|
|
|
|
def _create_trend_visualization(self, df: pd.DataFrame, name: str):
|
|
"""Create a trend visualization for a DataFrame"""
|
|
# Choose numeric columns for display
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:3] # Limit to first 3
|
|
|
|
if len(numeric_cols) == 0:
|
|
plt.text(0.5, 0.5, "No numeric data found for trend visualization",
|
|
ha='center', va='center', fontsize=12)
|
|
return
|
|
|
|
# Look for date index or use running index
|
|
date_col = None
|
|
for col in df.columns:
|
|
if pd.api.types.is_datetime64_dtype(df[col]) or 'date' in col.lower() or 'time' in col.lower():
|
|
date_col = col
|
|
break
|
|
|
|
# Use date column as X-axis if available
|
|
if date_col:
|
|
for col in numeric_cols:
|
|
plt.plot(df[date_col], df[col], marker='o', linestyle='-', label=col)
|
|
else:
|
|
# Otherwise use index numbers
|
|
for col in numeric_cols:
|
|
plt.plot(range(len(df)), df[col], marker='o', linestyle='-', label=col)
|
|
|
|
plt.title(f'Trend Analysis - {name}')
|
|
plt.legend()
|
|
plt.grid(True)
|
|
plt.tight_layout()
|
|
|
|
def _create_comparative_visualization(self, df: pd.DataFrame, name: str):
|
|
"""Create a comparative visualization for a DataFrame"""
|
|
# Choose numeric columns for display
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:4] # Limit to first 4
|
|
|
|
if len(numeric_cols) == 0:
|
|
plt.text(0.5, 0.5, "No numeric data found for comparative visualization",
|
|
ha='center', va='center', fontsize=12)
|
|
return
|
|
|
|
# Find categorical column for grouping
|
|
categorical_cols = df.select_dtypes(include=['object', 'category']).columns
|
|
if len(categorical_cols) > 0:
|
|
category_col = categorical_cols[0]
|
|
|
|
# Display maximum of first 7 categories
|
|
top_categories = df[category_col].value_counts().head(7).index
|
|
filtered_df = df[df[category_col].isin(top_categories)]
|
|
|
|
# Create grouped bar chart
|
|
numeric_col = numeric_cols[0]
|
|
sns.barplot(x=category_col, y=numeric_col, data=filtered_df)
|
|
plt.title(f'Comparison of {numeric_col} by {category_col} - {name}')
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
else:
|
|
# Comparative visualization for numeric columns without categories
|
|
if len(numeric_cols) >= 2:
|
|
# Scatter plot for first two numeric columns
|
|
sns.scatterplot(x=numeric_cols[0], y=numeric_cols[1], data=df)
|
|
plt.title(f'Comparison of {numeric_cols[0]} vs {numeric_cols[1]} - {name}')
|
|
plt.tight_layout()
|
|
else:
|
|
# Simple bar chart for a single numeric column
|
|
plt.bar(range(min(20, len(df))), df[numeric_cols[0]].head(20))
|
|
plt.title(f'Top 20 Values for {numeric_cols[0]} - {name}')
|
|
plt.tight_layout()
|
|
|
|
def _create_predictive_visualization(self, df: pd.DataFrame, name: str):
|
|
"""Create a simple predictive visualization for a DataFrame"""
|
|
# Choose numeric columns for display
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:2] # Limit to first 2
|
|
|
|
if len(numeric_cols) < 2:
|
|
plt.text(0.5, 0.5, "At least 2 numeric columns required for predictive visualization",
|
|
ha='center', va='center', fontsize=12)
|
|
return
|
|
|
|
# Simple scatter plot with trend line
|
|
x = df[numeric_cols[0]].values
|
|
y = df[numeric_cols[1]].values
|
|
|
|
# Linear regression with NumPy
|
|
valid_indices = ~(np.isnan(x) | np.isnan(y))
|
|
if np.sum(valid_indices) > 1: # At least 2 valid data points
|
|
x_valid = x[valid_indices].reshape(-1, 1)
|
|
y_valid = y[valid_indices]
|
|
|
|
# Linear regression with NumPy polyfit
|
|
if len(x_valid) > 1:
|
|
coeffs = np.polyfit(x_valid.flatten(), y_valid, 1)
|
|
poly_func = np.poly1d(coeffs)
|
|
|
|
# Create prediction line
|
|
x_line = np.linspace(np.min(x_valid), np.max(x_valid), 100).reshape(-1, 1)
|
|
y_pred = poly_func(x_line)
|
|
|
|
# Create scatter plot with trend line
|
|
plt.scatter(x_valid, y_valid, alpha=0.7)
|
|
plt.plot(x_line, y_pred, 'r-', linewidth=2)
|
|
plt.title(f'Linear Regression: {numeric_cols[1]} vs {numeric_cols[0]} - {name}')
|
|
plt.xlabel(numeric_cols[0])
|
|
plt.ylabel(numeric_cols[1])
|
|
plt.tight_layout()
|
|
else:
|
|
plt.text(0.5, 0.5, "Insufficient data for predictive analysis",
|
|
ha='center', va='center', fontsize=12)
|
|
|
|
def _create_clustering_visualization(self, df: pd.DataFrame, name: str):
|
|
"""Create a clustering visualization for a DataFrame"""
|
|
# Choose numeric columns for display
|
|
numeric_cols = df.select_dtypes(include=['number']).columns[:2] # Limit to first 2
|
|
|
|
if len(numeric_cols) < 2:
|
|
plt.text(0.5, 0.5, "At least 2 numeric columns required for clustering visualization",
|
|
ha='center', va='center', fontsize=12)
|
|
return
|
|
|
|
# Extract data for first two numeric columns
|
|
x = df[numeric_cols[0]].values
|
|
y = df[numeric_cols[1]].values
|
|
|
|
# Find categorical column for color coding
|
|
categorical_cols = df.select_dtypes(include=['object', 'category']).columns
|
|
|
|
if len(categorical_cols) > 0:
|
|
# Use first categorical column for color coding
|
|
category_col = categorical_cols[0]
|
|
categories = df[category_col].astype('category').cat.codes
|
|
|
|
# Create scatter plot with color coding by category
|
|
plt.scatter(x, y, c=categories, cmap='viridis', alpha=0.7)
|
|
plt.colorbar(label=category_col)
|
|
else:
|
|
# Simple scatter plot without color coding
|
|
plt.scatter(x, y, alpha=0.7)
|
|
|
|
plt.title(f'Clustering Visualization: {numeric_cols[1]} vs {numeric_cols[0]} - {name}')
|
|
plt.xlabel(numeric_cols[0])
|
|
plt.ylabel(numeric_cols[1])
|
|
plt.tight_layout()
|
|
|
|
def _create_general_visualization(self, df: pd.DataFrame, name: str):
|
|
"""Create a general visualization for a DataFrame"""
|
|
# Choose numeric columns for display
|
|
numeric_cols = df.select_dtypes(include=['number']).columns
|
|
|
|
if len(numeric_cols) == 0:
|
|
plt.text(0.5, 0.5, "No numeric data found for visualization",
|
|
ha='center', va='center', fontsize=12)
|
|
return
|
|
|
|
# Create correlation matrix if multiple numeric columns available
|
|
if len(numeric_cols) >= 2:
|
|
corr_matrix = df[numeric_cols].corr()
|
|
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
|
|
plt.title(f'Correlation Matrix - {name}')
|
|
else:
|
|
# Simple distribution for a single numeric column
|
|
sns.histplot(df[numeric_cols[0]].dropna(), kde=True)
|
|
plt.title(f'Distribution of {numeric_cols[0]} - {name}')
|
|
|
|
plt.tight_layout()
|
|
|
|
def _get_figure_as_base64(self, format_type: str = 'png') -> str:
|
|
"""
|
|
Convert current matplotlib figure to base64 string.
|
|
|
|
Args:
|
|
format_type: Image format (png, jpg, svg)
|
|
|
|
Returns:
|
|
Base64 encoded string of the figure
|
|
"""
|
|
buffer = io.BytesIO()
|
|
plt.savefig(buffer, format=format_type, dpi=self.chart_dpi)
|
|
buffer.seek(0)
|
|
image_data = buffer.getvalue()
|
|
buffer.close()
|
|
|
|
# Convert to base64
|
|
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
|
return image_base64
|
|
|
|
    async def _generate_analysis_document(self, prompt: str, context: str, data_insights: str,
                                          analysis_type: str, format_type: str,
                                          output_label: str, output_description: str) -> str:
        """
        Generate an analysis document based on the data and prompt.

        Args:
            prompt: Task description
            context: Document context as text
            data_insights: Insights from the data
            analysis_type: Analysis type
            format_type: Output format
            output_label: Output filename
            output_description: Description of desired output

        Returns:
            Generated document content. On failure (no AI service or an API
            error) a markdown error stub is returned instead of raising.
        """
        # Defensive: process_task also checks this, but the method may be
        # called directly.
        if not self.ai_service:
            return f"# Data Analysis ({analysis_type})\n\nAnalysis could not be generated: AI service not available."

        # Create specialized prompt based on analysis type
        system_prompt = f"""
You are a specialized data analyst focused on {analysis_type} analyses.

Create a detailed analysis of the provided data and/or text content.
Your analysis should include:
1. A summary of the data/content
2. Key findings and insights
3. Supporting evidence and calculations
4. Clear conclusions
5. Recommendations where appropriate

Format the analysis in the requested output format.
"""

        # Create extended prompt with all available information
        generation_prompt = f"""
Create a detailed {analysis_type} analysis for the following task:

TASK:
{prompt}

CONTEXT:
{context if context else 'No additional context available.'}

DATA INSIGHTS:
{data_insights if data_insights else 'No data insights available.'}

OUTPUT REQUIREMENTS:
- Filename: {output_label}
- Description: {output_description}
- Format: {format_type}

The analysis should be professional and clearly structured, considering all available information.

The output must perfectly match the {format_type} format.
"""

        try:
            # Call AI for analysis
            content = await self.ai_service.call_api([
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": generation_prompt}
            ])

            # For markdown format, ensure there's a title at the beginning
            if format_type == "markdown" and not content.strip().startswith("# "):
                content = f"# Data Analysis ({analysis_type})\n\n{content}"

            return content
        except Exception as e:
            # Swallow API failures and return an error stub so a single
            # failed document does not abort the whole task.
            logger.error(f"Error generating analysis: {str(e)}")
            return f"# Data Analysis ({analysis_type})\n\nError generating analysis: {str(e)}"
|
|
|
|
|
|
# Factory function for the Analyst agent
|
|
# Factory function for the Analyst agent
def get_analyst_agent():
    """
    Factory function that returns an instance of the Analyst agent.

    Returns:
        An instance of the Analyst agent
    """
    # Each call creates a fresh, unconfigured instance; callers must invoke
    # set_dependencies() before the agent can process tasks.
    return AgentAnalyst()