# gateway/modules/chat_agent_analyst.py
# (file listing metadata: 2025-04-20 23:53:37 +02:00, 748 lines, 30 KiB, Python, no EOL)
"""
Data analyst agent for analysis and interpretation of data.
Optimized for the new task-based processing.
"""
import logging
import json
import re
import uuid
import io
import base64
from typing import Dict, Any, List, Optional
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from modules.chat_registry import AgentBase
logger = logging.getLogger(__name__)
class AgentAnalyst(AgentBase):
    """Agent for analysis and interpretation of data."""

    def __init__(self):
        """Initialize the data analysis agent and its visualization defaults."""
        super().__init__()
        self.name = "analyst"
        self.description = "Analyzes and interprets data using statistical methods and visualizations"
        self.capabilities = [
            "data_analysis",
            "pattern_recognition",
            "statistics",
            "visualization",
            "data_interpretation"
        ]
        # Visualization settings
        self.plt_style = 'seaborn-v0_8-whitegrid'  # style name used by matplotlib >= 3.6
        self.default_figsize = (10, 6)             # (width, height) in inches
        self.chart_dpi = 100
        # The seaborn style names changed across matplotlib versions; an unknown
        # name raises OSError. Fall back to the default style instead of failing
        # agent construction.
        try:
            plt.style.use(self.plt_style)
        except (OSError, ValueError):
            logger.warning("Matplotlib style %r not available; using default style", self.plt_style)
    async def process_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a standardized task structure and perform data analysis.

        Args:
            task: A dictionary containing:
                - task_id: Unique ID for this task
                - prompt: The main instruction for the agent
                - input_documents: List of documents to process
                - output_specifications: List of required output documents
                - context: Additional contextual information

        Returns:
            A dictionary containing:
                - feedback: Text response explaining the analysis results
                - documents: List of created document objects
        """
        try:
            # Extract relevant task information
            prompt = task.get("prompt", "")
            input_documents = task.get("input_documents", [])
            output_specs = task.get("output_specifications", [])
            # Check if AI service is available. Without it no text analysis can
            # be generated, so bail out early with an explanatory feedback.
            # NOTE(review): self.ai_service is presumably injected by the agent
            # registry — confirm against AgentBase.
            if not self.ai_service:
                logger.error("No AI service configured for the Analyst agent")
                return {
                    "feedback": "The Analyst agent is not properly configured.",
                    "documents": []
                }
            # Extract data from input documents: parsed DataFrames for CSV/JSON
            # plus the concatenated raw text of all text contents.
            data_frames, document_context = self._extract_data_from_documents(input_documents)
            # Analyzable content = at least one parsed DataFrame, OR a prompt
            # longer than 10 characters (very short prompts are treated as noise).
            have_analyzable_content = len(data_frames) > 0 or (prompt and len(prompt.strip()) > 10)
            if not have_analyzable_content:
                # Warning if no analyzable content available
                logger.warning("No analyzable content found")
                feedback = "I couldn't find any processable data in the provided documents."
                return {
                    "feedback": feedback,
                    "documents": []
                }
            # Determine analysis type (statistical/trend/comparative/...) from
            # keywords in the prompt.
            analysis_type = self._determine_analysis_type(prompt)
            logger.info(f"Performing {analysis_type} analysis")
            # Store generated documents
            generated_documents = []
            # Extract data insights if DataFrames are available; this text is
            # fed into the AI prompt for text-based analysis documents.
            data_insights = ""
            if data_frames:
                data_insights = self._extract_data_insights(data_frames)
                logger.info(f"Extracted insights from {len(data_frames)} datasets")
            # Generate an appropriate document for each requested output
            for spec in output_specs:
                output_label = spec.get("label", "")
                output_description = spec.get("description", "")
                # Determine format based on file extension
                format_type = self._determine_format_type(output_label)
                # Visualization outputs are recognized either by keywords in the
                # filename or by an image file extension.
                if "chart" in output_label.lower() or "plot" in output_label.lower() or "visualization" in output_label.lower() or format_type in ["png", "jpg", "svg"]:
                    # Generate visualization document if data available
                    if data_frames:
                        viz_document = self._generate_visualization_document(data_frames, analysis_type, prompt, output_label)
                        generated_documents.append(viz_document)
                    else:
                        # Fallback if no data: emit a plain-text placeholder.
                        generated_documents.append({
                            "label": output_label,
                            "content": "No data available for visualization."
                        })
                else:
                    # Create text-based analysis via the AI service.
                    content = await self._generate_analysis_document(
                        prompt,
                        document_context,
                        data_insights,
                        analysis_type,
                        format_type,
                        output_label,
                        output_description
                    )
                    generated_documents.append({
                        "label": output_label,
                        "content": content
                    })
            # If no specific outputs requested, create standard documents:
            # a markdown report, plus a PNG visualization when data exists.
            if not output_specs:
                # Standard analysis
                analysis_content = await self._generate_analysis_document(
                    prompt,
                    document_context,
                    data_insights,
                    analysis_type,
                    "markdown",
                    "analysis_report.md",
                    "Analysis report"
                )
                generated_documents.append({
                    "label": "analysis_report.md",
                    "content": analysis_content
                })
                # Add visualization if data available
                if data_frames:
                    viz_document = self._generate_visualization_document(data_frames, analysis_type, prompt, "data_visualization.png")
                    generated_documents.append(viz_document)
            # Create feedback summarizing what was produced.
            if data_frames:
                feedback = f"I analyzed {len(data_frames)} datasets and created {len(generated_documents)} documents with the results."
            else:
                feedback = f"I performed a text analysis and created {len(generated_documents)} documents with the results."
            return {
                "feedback": feedback,
                "documents": generated_documents
            }
        except Exception as e:
            # Broad catch is deliberate: the agent must always return a
            # well-formed response instead of propagating the error upward.
            error_msg = f"Error during data analysis: {str(e)}"
            logger.error(error_msg)
            return {
                "feedback": f"An error occurred during data analysis: {str(e)}",
                "documents": []
            }
    def _extract_data_from_documents(self, documents: List[Dict[str, Any]]) -> tuple:
        """
        Extract data from input documents.

        Text contents are concatenated into one context string; CSV and JSON
        documents are additionally parsed into pandas DataFrames. Parsing
        failures are logged and skipped, never raised.

        Args:
            documents: List of input documents. Each document carries "name"
                and "contents"; each content carries "data" and optional
                "metadata" with an "is_text" flag.

        Returns:
            Tuple of (dict mapping document name -> DataFrame, combined context text)
        """
        data_frames = {}
        document_context = ""
        for doc in documents:
            doc_name = doc.get("name", "unnamed")
            # Separator so downstream prompts can attribute text to its source file.
            document_context += f"\n\n--- {doc_name} ---\n"
            for content in doc.get("contents", []):
                # Extract text content and add to context
                if content.get("metadata", {}).get("is_text", False):
                    document_context += content.get("data", "")
                # Try to parse CSV, JSON, or other data files from text.
                # NOTE(review): parsing is attempted for every content entry of a
                # .csv/.json document, not only those flagged is_text — confirm
                # this is intended; non-text payloads simply fail into the
                # warning path below.
                if doc_name.lower().endswith('.csv'):
                    try:
                        df = pd.read_csv(io.StringIO(content.get("data", "")))
                        df = self._preprocess_dataframe(df)
                        data_frames[doc_name] = df
                        logger.info(f"Extracted CSV data from {doc_name}: {df.shape}")
                    except Exception as e:
                        logger.warning(f"Error parsing CSV {doc_name}: {str(e)}")
                elif doc_name.lower().endswith('.json'):
                    try:
                        json_data = json.loads(content.get("data", ""))
                        if isinstance(json_data, list):
                            # Top-level array: rows map directly to a DataFrame.
                            df = pd.DataFrame(json_data)
                        elif isinstance(json_data, dict):
                            # Convert nested JSON to DataFrame
                            if any(isinstance(v, list) for v in json_data.values()):
                                # If lists present, use the first non-empty one
                                # as the table.
                                for key, value in json_data.items():
                                    if isinstance(value, list) and len(value) > 0:
                                        df = pd.DataFrame(value)
                                        break
                                else:
                                    # All list values were empty: skip this content.
                                    continue
                            else:
                                # Flat dict becomes a single-row DataFrame.
                                df = pd.DataFrame([json_data])
                        else:
                            # Scalar JSON (number/string/bool): nothing tabular.
                            continue
                        df = self._preprocess_dataframe(df)
                        data_frames[doc_name] = df
                        logger.info(f"Extracted JSON data from {doc_name}: {df.shape}")
                    except Exception as e:
                        logger.warning(f"Error parsing JSON {doc_name}: {str(e)}")
        return data_frames, document_context
def _determine_format_type(self, output_label: str) -> str:
"""
Determine the format type based on the filename.
Args:
output_label: Output filename
Returns:
Format type (markdown, html, text, png, etc.)
"""
output_label_lower = output_label.lower()
if output_label_lower.endswith('.md'):
return "markdown"
elif output_label_lower.endswith('.html'):
return "html"
elif output_label_lower.endswith('.txt'):
return "text"
elif output_label_lower.endswith('.json'):
return "json"
elif output_label_lower.endswith('.csv'):
return "csv"
elif output_label_lower.endswith('.png'):
return "png"
elif output_label_lower.endswith('.jpg') or output_label_lower.endswith('.jpeg'):
return "jpg"
elif output_label_lower.endswith('.svg'):
return "svg"
else:
# Default to markdown
return "markdown"
def _preprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
"""Perform basic preprocessing for a DataFrame"""
if df.empty:
return df
# Remove completely empty rows and columns
df = df.dropna(how='all')
df = df.dropna(axis=1, how='all')
# String conversion to numeric values where appropriate
for col in df.columns:
# Skip if already numeric
if pd.api.types.is_numeric_dtype(df[col]):
continue
# Skip if predominantly non-numeric strings
if df[col].dtype == 'object':
# Check if more than 80% of non-NA values could be numeric
non_na_values = df[col].dropna()
if len(non_na_values) == 0:
continue
# Attempt conversion to numeric values
numeric_count = pd.to_numeric(non_na_values, errors='coerce').notna().sum()
if numeric_count / len(non_na_values) > 0.8:
# More than 80% can be converted to numeric values
df[col] = pd.to_numeric(df[col], errors='coerce')
return df
def _determine_analysis_type(self, task: str) -> str:
"""
Determine the analysis type based on the task.
Args:
task: The analysis task
Returns:
Analysis type
"""
# Using universal patterns rather than language-specific keywords
task_lower = task.lower()
# Check for statistical analysis
if "statistical" in task_lower or "stats" in task_lower:
return "statistical"
# Check for trend analysis
elif "trend" in task_lower or "time series" in task_lower:
return "trend"
# Check for comparative analysis
elif "compare" in task_lower or "comparison" in task_lower or "vs" in task_lower:
return "comparative"
# Check for predictive analysis
elif "predict" in task_lower or "forecast" in task_lower:
return "predictive"
# Check for clustering or categorization
elif "cluster" in task_lower or "segment" in task_lower or "classify" in task_lower:
return "clustering"
# Default: general analysis
else:
return "general"
def _extract_data_insights(self, data_frames: Dict[str, pd.DataFrame]) -> str:
"""
Extract basic insights from DataFrames.
Args:
data_frames: Dictionary of DataFrames
Returns:
Extracted insights as text
"""
insights = []
for name, df in data_frames.items():
if df.empty:
continue
insight = f"Dataset: {name}\n"
insight += f"Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
insight += f"Columns: {', '.join(df.columns.tolist())}\n"
# Basic statistics for numeric columns
numeric_cols = df.select_dtypes(include=['number']).columns
if len(numeric_cols) > 0:
insight += "Statistics for numeric columns:\n"
for col in numeric_cols[:5]: # Limit to first 5 columns
stats = df[col].describe()
insight += f" {col}: min={stats['min']:.2f}, max={stats['max']:.2f}, mean={stats['mean']:.2f}, median={df[col].median():.2f}\n"
# Categorical column values
cat_cols = df.select_dtypes(include=['object', 'category']).columns
if len(cat_cols) > 0:
insight += "Categorical columns:\n"
for col in cat_cols[:3]: # Limit to first 3 columns
# Get top 3 values
top_values = df[col].value_counts().head(3)
vals_str = ", ".join([f"{val} ({count})" for val, count in top_values.items()])
insight += f" {col}: {df[col].nunique()} unique values. Most common values: {vals_str}\n"
insights.append(insight)
return "\n\n".join(insights)
def _generate_visualization_document(self, data_frames: Dict[str, pd.DataFrame],
analysis_type: str, prompt: str,
output_label: str) -> Dict[str, Any]:
"""
Generate a visualization document based on the data and analysis type.
Args:
data_frames: Dictionary of DataFrames
analysis_type: Analysis type
prompt: Original task description
output_label: Output filename
Returns:
Visualization document
"""
# Determine format from filename
format_type = output_label.split('.')[-1].lower() if '.' in output_label else 'png'
# Set default format if unknown
if format_type not in ['png', 'jpg', 'jpeg', 'svg']:
format_type = 'png'
# Use first DataFrame for visualization
if not data_frames:
return {
"label": output_label,
"content": "No data available for visualization."
}
# Get name and DataFrame of first dataset
name, df = next(iter(data_frames.items()))
# Create different visualization types based on analysis type and data
plt.figure(figsize=self.default_figsize)
if analysis_type == "statistical":
# Statistical visualization
self._create_statistical_visualization(df, name)
elif analysis_type == "trend":
# Trend visualization
self._create_trend_visualization(df, name)
elif analysis_type == "comparative":
# Comparative visualization
self._create_comparative_visualization(df, name)
elif analysis_type == "predictive":
# Predictive visualization (simple example)
self._create_predictive_visualization(df, name)
elif analysis_type == "clustering":
# Clustering visualization
self._create_clustering_visualization(df, name)
else:
# General visualization
self._create_general_visualization(df, name)
# Save figure as Base64 string
img_data = self._get_figure_as_base64(format_type)
plt.close()
# Prepare content for document based on format
if format_type in ['png', 'jpg', 'jpeg']:
content_str = img_data
elif format_type == 'svg':
# SVG content as text
buffer = io.StringIO()
plt.savefig(buffer, format='svg')
content_str = buffer.getvalue()
buffer.close()
else:
# Fallback to PNG
content_str = img_data
return {
"label": output_label,
"content": content_str
}
def _create_statistical_visualization(self, df: pd.DataFrame, name: str):
"""Create a statistical visualization for a DataFrame"""
# Choose numeric columns for display
numeric_cols = df.select_dtypes(include=['number']).columns[:4] # Limit to first 4
if len(numeric_cols) == 0:
plt.text(0.5, 0.5, "No numeric data found for statistical visualization",
ha='center', va='center', fontsize=12)
return
# Visualize distribution of first numeric column
main_col = numeric_cols[0]
# Create histogram with KDE
sns.histplot(df[main_col].dropna(), kde=True)
plt.title(f'Distribution of {main_col} - {name}')
plt.xlabel(main_col)
plt.ylabel('Frequency')
plt.tight_layout()
def _create_trend_visualization(self, df: pd.DataFrame, name: str):
"""Create a trend visualization for a DataFrame"""
# Choose numeric columns for display
numeric_cols = df.select_dtypes(include=['number']).columns[:3] # Limit to first 3
if len(numeric_cols) == 0:
plt.text(0.5, 0.5, "No numeric data found for trend visualization",
ha='center', va='center', fontsize=12)
return
# Look for date index or use running index
date_col = None
for col in df.columns:
if pd.api.types.is_datetime64_dtype(df[col]) or 'date' in col.lower() or 'time' in col.lower():
date_col = col
break
# Use date column as X-axis if available
if date_col:
for col in numeric_cols:
plt.plot(df[date_col], df[col], marker='o', linestyle='-', label=col)
else:
# Otherwise use index numbers
for col in numeric_cols:
plt.plot(range(len(df)), df[col], marker='o', linestyle='-', label=col)
plt.title(f'Trend Analysis - {name}')
plt.legend()
plt.grid(True)
plt.tight_layout()
def _create_comparative_visualization(self, df: pd.DataFrame, name: str):
"""Create a comparative visualization for a DataFrame"""
# Choose numeric columns for display
numeric_cols = df.select_dtypes(include=['number']).columns[:4] # Limit to first 4
if len(numeric_cols) == 0:
plt.text(0.5, 0.5, "No numeric data found for comparative visualization",
ha='center', va='center', fontsize=12)
return
# Find categorical column for grouping
categorical_cols = df.select_dtypes(include=['object', 'category']).columns
if len(categorical_cols) > 0:
category_col = categorical_cols[0]
# Display maximum of first 7 categories
top_categories = df[category_col].value_counts().head(7).index
filtered_df = df[df[category_col].isin(top_categories)]
# Create grouped bar chart
numeric_col = numeric_cols[0]
sns.barplot(x=category_col, y=numeric_col, data=filtered_df)
plt.title(f'Comparison of {numeric_col} by {category_col} - {name}')
plt.xticks(rotation=45)
plt.tight_layout()
else:
# Comparative visualization for numeric columns without categories
if len(numeric_cols) >= 2:
# Scatter plot for first two numeric columns
sns.scatterplot(x=numeric_cols[0], y=numeric_cols[1], data=df)
plt.title(f'Comparison of {numeric_cols[0]} vs {numeric_cols[1]} - {name}')
plt.tight_layout()
else:
# Simple bar chart for a single numeric column
plt.bar(range(min(20, len(df))), df[numeric_cols[0]].head(20))
plt.title(f'Top 20 Values for {numeric_cols[0]} - {name}')
plt.tight_layout()
def _create_predictive_visualization(self, df: pd.DataFrame, name: str):
"""Create a simple predictive visualization for a DataFrame"""
# Choose numeric columns for display
numeric_cols = df.select_dtypes(include=['number']).columns[:2] # Limit to first 2
if len(numeric_cols) < 2:
plt.text(0.5, 0.5, "At least 2 numeric columns required for predictive visualization",
ha='center', va='center', fontsize=12)
return
# Simple scatter plot with trend line
x = df[numeric_cols[0]].values
y = df[numeric_cols[1]].values
# Linear regression with NumPy
valid_indices = ~(np.isnan(x) | np.isnan(y))
if np.sum(valid_indices) > 1: # At least 2 valid data points
x_valid = x[valid_indices].reshape(-1, 1)
y_valid = y[valid_indices]
# Linear regression with NumPy polyfit
if len(x_valid) > 1:
coeffs = np.polyfit(x_valid.flatten(), y_valid, 1)
poly_func = np.poly1d(coeffs)
# Create prediction line
x_line = np.linspace(np.min(x_valid), np.max(x_valid), 100).reshape(-1, 1)
y_pred = poly_func(x_line)
# Create scatter plot with trend line
plt.scatter(x_valid, y_valid, alpha=0.7)
plt.plot(x_line, y_pred, 'r-', linewidth=2)
plt.title(f'Linear Regression: {numeric_cols[1]} vs {numeric_cols[0]} - {name}')
plt.xlabel(numeric_cols[0])
plt.ylabel(numeric_cols[1])
plt.tight_layout()
else:
plt.text(0.5, 0.5, "Insufficient data for predictive analysis",
ha='center', va='center', fontsize=12)
def _create_clustering_visualization(self, df: pd.DataFrame, name: str):
"""Create a clustering visualization for a DataFrame"""
# Choose numeric columns for display
numeric_cols = df.select_dtypes(include=['number']).columns[:2] # Limit to first 2
if len(numeric_cols) < 2:
plt.text(0.5, 0.5, "At least 2 numeric columns required for clustering visualization",
ha='center', va='center', fontsize=12)
return
# Extract data for first two numeric columns
x = df[numeric_cols[0]].values
y = df[numeric_cols[1]].values
# Find categorical column for color coding
categorical_cols = df.select_dtypes(include=['object', 'category']).columns
if len(categorical_cols) > 0:
# Use first categorical column for color coding
category_col = categorical_cols[0]
categories = df[category_col].astype('category').cat.codes
# Create scatter plot with color coding by category
plt.scatter(x, y, c=categories, cmap='viridis', alpha=0.7)
plt.colorbar(label=category_col)
else:
# Simple scatter plot without color coding
plt.scatter(x, y, alpha=0.7)
plt.title(f'Clustering Visualization: {numeric_cols[1]} vs {numeric_cols[0]} - {name}')
plt.xlabel(numeric_cols[0])
plt.ylabel(numeric_cols[1])
plt.tight_layout()
def _create_general_visualization(self, df: pd.DataFrame, name: str):
"""Create a general visualization for a DataFrame"""
# Choose numeric columns for display
numeric_cols = df.select_dtypes(include=['number']).columns
if len(numeric_cols) == 0:
plt.text(0.5, 0.5, "No numeric data found for visualization",
ha='center', va='center', fontsize=12)
return
# Create correlation matrix if multiple numeric columns available
if len(numeric_cols) >= 2:
corr_matrix = df[numeric_cols].corr()
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
plt.title(f'Correlation Matrix - {name}')
else:
# Simple distribution for a single numeric column
sns.histplot(df[numeric_cols[0]].dropna(), kde=True)
plt.title(f'Distribution of {numeric_cols[0]} - {name}')
plt.tight_layout()
def _get_figure_as_base64(self, format_type: str = 'png') -> str:
"""
Convert current matplotlib figure to base64 string.
Args:
format_type: Image format (png, jpg, svg)
Returns:
Base64 encoded string of the figure
"""
buffer = io.BytesIO()
plt.savefig(buffer, format=format_type, dpi=self.chart_dpi)
buffer.seek(0)
image_data = buffer.getvalue()
buffer.close()
# Convert to base64
image_base64 = base64.b64encode(image_data).decode('utf-8')
return image_base64
    async def _generate_analysis_document(self, prompt: str, context: str, data_insights: str,
                                          analysis_type: str, format_type: str,
                                          output_label: str, output_description: str) -> str:
        """
        Generate an analysis document via the AI service.

        Args:
            prompt: Task description
            context: Document context as text
            data_insights: Insights extracted from the data
            analysis_type: Analysis type (statistical/trend/...)
            format_type: Output format (markdown/html/...)
            output_label: Output filename
            output_description: Description of the desired output

        Returns:
            Generated document content; on failure, a short markdown error
            report (this method never raises).
        """
        if not self.ai_service:
            return f"# Data Analysis ({analysis_type})\n\nAnalysis could not be generated: AI service not available."
        # System prompt: fixes the analyst persona and the required report sections.
        system_prompt = f"""
You are a specialized data analyst focused on {analysis_type} analyses.
Create a detailed analysis of the provided data and/or text content.
Your analysis should include:
1. A summary of the data/content
2. Key findings and insights
3. Supporting evidence and calculations
4. Clear conclusions
5. Recommendations where appropriate
Format the analysis in the requested output format.
"""
        # User prompt: bundles the task, document context, data insights and the
        # concrete output requirements into one message.
        generation_prompt = f"""
Create a detailed {analysis_type} analysis for the following task:
TASK:
{prompt}
CONTEXT:
{context if context else 'No additional context available.'}
DATA INSIGHTS:
{data_insights if data_insights else 'No data insights available.'}
OUTPUT REQUIREMENTS:
- Filename: {output_label}
- Description: {output_description}
- Format: {format_type}
The analysis should be professional and clearly structured, considering all available information.
The output must perfectly match the {format_type} format.
"""
        try:
            # Call AI for analysis.
            # NOTE(review): call_api presumably takes an OpenAI-style message
            # list and returns the completion text — confirm against the
            # ai_service implementation.
            content = await self.ai_service.call_api([
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": generation_prompt}
            ])
            # For markdown format, ensure there's a title at the beginning
            if format_type == "markdown" and not content.strip().startswith("# "):
                content = f"# Data Analysis ({analysis_type})\n\n{content}"
            return content
        except Exception as e:
            # Degrade to an inline error report rather than propagating,
            # so the caller still receives a document.
            logger.error(f"Error generating analysis: {str(e)}")
            return f"# Data Analysis ({analysis_type})\n\nError generating analysis: {str(e)}"
# Factory function for the Analyst agent
def get_analyst_agent():
    """
    Factory function that returns a new instance of the Analyst agent.

    Returns:
        A freshly constructed AgentAnalyst.
    """
    agent = AgentAnalyst()
    return agent