# gateway/gwserver/modules/agentservice_agent_analyst.py
"""
Datenanalyst-Agent für die Analyse und Interpretation von Daten.
Angepasst für das refaktorisierte Core-Modul mit AgentCommunicationProtocol.
"""
import logging
import traceback
import json
import re
import uuid
import io
import base64
from typing import List, Dict, Any, Optional, Union, Tuple
from datetime import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from modules.agentservice_base import BaseAgent
from connectors.connector_aichat_openai import ChatService
from modules.agentservice_utils import WorkflowUtils, MessageUtils, LoggingUtils, FileUtils
from modules.agentservice_protocol import AgentMessage, AgentCommunicationProtocol
logger = logging.getLogger(__name__)
class AnalystAgent(BaseAgent):
"""Agent for data analysis and interpretation"""
def __init__(self):
"""Initialize the data analyst agent"""
super().__init__()
self.id = "analyst_agent"
self.name = "Data Analyst"
self.type = "analyst"
self.description = "Analyzes and interprets data"
self.capabilities = "data_analysis,pattern_recognition,statistics,visualization,data_interpretation"
self.result_format = "AnalysisReport"
# Initialize AI service
self.ai_service = None
# Document capabilities
self.supports_documents = True
self.document_capabilities = ["read", "analyze", "extract"]
self.required_context = ["data_source", "analysis_objectives"]
self.document_handler = None
# Initialize protocol
self.protocol = AgentCommunicationProtocol()
# Initialize utilities
self.message_utils = MessageUtils()
self.file_utils = FileUtils()
# Setup visualization defaults
self.plt_style = 'seaborn-v0_8-whitegrid'
self.default_figsize = (10, 6)
self.chart_dpi = 100
plt.style.use(self.plt_style)
def get_agent_info(self) -> Dict[str, Any]:
"""Get agent information for agent registry"""
info = super().get_agent_info()
info.update({
"metadata": {
"supported_formats": ["csv", "xlsx", "json", "text"],
"analysis_types": ["statistical", "trend", "comparative", "predictive", "clustering", "general"],
"visualization_types": ["bar", "line", "scatter", "histogram", "box", "heatmap", "pie"]
}
})
return info
def set_document_handler(self, document_handler):
"""Set the document handler for file operations"""
self.document_handler = document_handler
"""
Main updates to the process_message method in AnalystAgent to consider all available content.
"""
async def process_message(self, message: Dict[str, Any], context: Dict[str, Any] = None) -> Dict[str, Any]:
"""
Process a message and perform data analysis.
Args:
message: Input message
context: Optional context
Returns:
Analysis response
"""
# Extract workflow_id from context or message
workflow_id = (context or {}).get("workflow_id") or message.get("workflow_id", "unknown")
# Get or create logging_utils
log_func = context.get("log_func") if context else None
logging_utils = LoggingUtils(workflow_id, log_func)
# Create status update using protocol
if log_func:
status_message = self.protocol.create_status_update_message(
status_description="Starting data analysis",
sender_id=self.id,
status="in_progress",
progress=0.0,
context_id=workflow_id
)
log_func(workflow_id, status_message.content, "info", self.id, self.name)
# Create response structure
response = {
"role": "assistant",
"content": "",
"agent_id": self.id,
"agent_type": self.type,
"agent_name": self.name,
"result_format": self.result_format,
"workflow_id": workflow_id,
"documents": []
}
try:
# Extract task from message
task = message.get("content", "")
# Process any attached documents and extract data
document_context = ""
data_frames = {}
if message.get("documents"):
logging_utils.info("Processing documents for analysis", "execution")
document_context, data_frames = await self._process_and_extract_data(message)
# Update progress
if log_func:
status_message = self.protocol.create_status_update_message(
status_description="Documents processed, performing analysis",
sender_id=self.id,
status="in_progress",
progress=0.4,
context_id=workflow_id
)
log_func(workflow_id, status_message.content, "info", self.id, self.name)
# Check whether we have data frames OR a substantial text task to analyze;
# the task text itself counts as analyzable content when no data files are attached
have_analyzable_content = len(data_frames) > 0 or (task and len(task.strip()) > 10)
if not have_analyzable_content:
# Only show warning if really no content is available
if message.get("documents"):
logging_utils.warning("No processable data found in the provided documents", "execution")
analysis_content = "## Data Analysis Report\n\nI couldn't find any processable data in the provided documents. Please ensure you've attached CSV, Excel, or other data files in a format I can analyze."
else:
logging_utils.warning("No documents or analyzable content provided for analysis", "execution")
analysis_content = "## Data Analysis Report\n\nNo data or sufficient text content was provided for analysis. Please provide text for analysis or attach data files for me to analyze."
response["content"] = analysis_content
return response
# Determine analysis type and perform analysis
analysis_type = self._determine_analysis_type(task)
logging_utils.info(f"Performing {analysis_type} analysis", "execution")
# Create enhanced prompt with document context
enhanced_prompt = self._create_enhanced_prompt(message, document_context, context)
# Generate visualization documents if data is available
visualization_documents = []
if data_frames:
logging_utils.info(f"Generating visualizations for {len(data_frames)} data sets", "execution")
visualization_documents = self._generate_visualizations(data_frames, analysis_type, workflow_id, task)
# Add visualizations to response documents
response["documents"].extend(visualization_documents)
# Update progress
if log_func:
status_message = self.protocol.create_status_update_message(
status_description="Visualizations created, finalizing analysis",
sender_id=self.id,
status="in_progress",
progress=0.7,
context_id=workflow_id
)
log_func(workflow_id, status_message.content, "info", self.id, self.name)
# Generate analysis with included data insights if we have data frames
analysis_content = ""
if data_frames:
# Extract data insights to include in the analysis
data_insights = self._extract_data_insights(data_frames)
# Add insights to the prompt
enhanced_prompt += f"\n\n=== DATA INSIGHTS ===\n{data_insights}"
# Generate analysis with data insights
analysis_content = await self._generate_analysis(enhanced_prompt, analysis_type)
# Include references to the visualization documents
if visualization_documents:
viz_references = "\n\n## Visualizations\n\n"
viz_references += "The following visualizations have been created to help understand the data:\n\n"
for i, doc in enumerate(visualization_documents, 1):
doc_source = doc.get("source", {})
doc_name = doc_source.get("name", f"Visualization {i}")
viz_references += f"{i}. **{doc_name}** - Available as an attached document\n"
analysis_content += viz_references
else:
# No data frames available, so analyze the provided text content directly
logging_utils.info("No data frames available, analyzing text content", "execution")
analysis_content = await self._generate_analysis(enhanced_prompt, analysis_type)
# Final progress update
if log_func:
status_message = self.protocol.create_status_update_message(
status_description="Analysis completed",
sender_id=self.id,
status="completed",
progress=1.0,
context_id=workflow_id
)
log_func(workflow_id, status_message.content, "info", self.id, self.name)
# Set the content in the response
response["content"] = analysis_content
# Finish by sending result message to protocol if needed
if context and context.get("require_protocol_message"):
result_message = self.send_analysis_result(
analysis_content=analysis_content,
sender_id=self.id,
receiver_id=context.get("receiver_id", "workflow"),
task_id=context.get("task_id", f"analysis_{uuid.uuid4()}"),
analysis_data={
"analysis_type": analysis_type,
"visualization_count": len(visualization_documents),
"data_frame_count": len(data_frames)
},
context_id=workflow_id
)
# Just log the message creation, don't need to return it
logging_utils.info(f"Created protocol result message: {result_message.id}", "execution")
return response
except Exception as e:
error_msg = f"Error during data analysis: {str(e)}"
logging_utils.error(error_msg, "error")
# Create error response using protocol
error_message = self.protocol.create_error_message(
error_description=error_msg,
sender_id=self.id,
error_type="analysis",
error_details={"traceback": traceback.format_exc()},
context_id=workflow_id
)
# Set error content in the response
response["content"] = f"## Error during data analysis\n\n{error_msg}\n\n```\n{traceback.format_exc()}\n```"
response["status"] = "error"
return response
"""
Add _create_enhanced_prompt method to better handle text content in analysis.
"""
def _create_enhanced_prompt(self, message: Dict[str, Any], document_context: str, context: Dict[str, Any] = None) -> str:
"""
Create an enhanced prompt for analysis that integrates all available content.
Args:
message: The original message
document_context: Context extracted from documents
context: Optional additional context
Returns:
Enhanced prompt for analysis
"""
# Get original task/prompt
task = message.get("content", "")
# Add context information if available
context_info = ""
if context:
# Add any dependency outputs from previous activities
if "dependency_outputs" in context:
dependency_context = context.get("dependency_outputs", {})
for name, value in dependency_context.items():
if isinstance(value, dict) and "content" in value:
context_info += f"\n\n=== INPUT FROM {name.upper()} ===\n{value['content']}"
else:
context_info += f"\n\n=== INPUT FROM {name.upper()} ===\n{str(value)}"
# Add expected format information
if "expected_format" in context:
context_info += f"\n\nExpected output format: {context.get('expected_format')}"
# Start with task
enhanced_prompt = f"ANALYSIS TASK:\n{task}"
# Add any context information
if context_info:
enhanced_prompt += f"\n\n{context_info}"
# Add document context if available
if document_context:
enhanced_prompt += f"\n\n=== DOCUMENT CONTENT ===\n{document_context}"
else:
# If no document content, explicitly note that we're analyzing the text content directly
enhanced_prompt += "\n\nNo data files were provided. Perform analysis on the text content itself."
# Add final instructions
if document_context:
enhanced_prompt += "\n\nBased on the data and documents provided, please perform a comprehensive analysis."
else:
enhanced_prompt += "\n\nBased on the text content provided, please perform a comprehensive analysis."
if task:
enhanced_prompt += f" Focus specifically on addressing: {task}"
enhanced_prompt += "\n\nProvide insights, patterns, and conclusions in a clear, structured format."
return enhanced_prompt
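# Resulting prompt layout (for reference):
#   ANALYSIS TASK: <task>
#   === INPUT FROM <DEPENDENCY> ===   (one section per dependency output, if any)
#   === DOCUMENT CONTENT ===          (merged document text, if any)
#   <closing instructions>
#   === DATA INSIGHTS ===             (appended later in process_message when data frames exist)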
async def _process_and_extract_data(self, message: Dict[str, Any]) -> Tuple[str, Dict[str, pd.DataFrame]]:
"""
Process documents and extract structured data.
Args:
message: Input message with documents
Returns:
Tuple of (document_context, data_frames_dict)
"""
document_context = ""
data_frames = {}
if not message.get("documents"):
return document_context, data_frames
# Extract document text (this will be our context)
if self.document_handler:
document_context = self.document_handler.merge_document_contents(message)
else:
document_context = self._extract_document_text(message)
# Identify and process data files (CSV, Excel, etc.)
for document in message.get("documents", []):
source = document.get("source", {})
filename = source.get("name", "")
file_id = source.get("id", 0)
content_type = source.get("content_type", "")
# Skip if not a recognizable data file
if not self._is_data_file(filename, content_type):
continue
try:
# Try to get file content through document handler first
file_content = None
if self.document_handler:
file_content = self.document_handler.get_file_content_from_message(message, file_id=file_id)
# Process based on file type
if filename.lower().endswith('.csv'):
df = self._process_csv(file_content, filename)
if df is not None:
data_frames[filename] = df
elif filename.lower().endswith(('.xlsx', '.xls')):
dfs = self._process_excel(file_content, filename)
for sheet_name, df in dfs.items():
data_frames[f"{filename}::{sheet_name}"] = df
elif filename.lower().endswith('.json'):
df = self._process_json(file_content, filename)
if df is not None:
data_frames[filename] = df
except Exception as e:
logger.error(f"Error processing file {filename}: {str(e)}")
return document_context, data_frames
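# Note: Excel workbooks contribute one DataFrame per sheet, keyed as
# "<filename>::<sheet_name>" (e.g. "report.xlsx::Sheet1"), while CSV and JSON
# files are keyed by filename alone.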
def _is_data_file(self, filename: str, content_type: str) -> bool:
"""Check if a file is a processable data file"""
if filename.lower().endswith(('.csv', '.xlsx', '.xls', '.json')):
return True
if content_type in ['text/csv', 'application/vnd.ms-excel',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'application/json']:
return True
return False
def _process_csv(self, file_content: Union[bytes, str], filename: str) -> Optional[pd.DataFrame]:
"""Process CSV file content into a pandas DataFrame"""
if file_content is None:
return None
try:
# Handle the case where file_content is already a string
if isinstance(file_content, str):
text_content = file_content
df = pd.read_csv(io.StringIO(text_content))
df = self._preprocess_dataframe(df)
return df
# Handle the case where file_content is bytes
else:
# Try various encodings
for encoding in ['utf-8', 'latin1', 'cp1252']:
try:
# Use StringIO to create a file-like object
text_content = file_content.decode(encoding)
df = pd.read_csv(io.StringIO(text_content))
# Basic preprocessing
df = self._preprocess_dataframe(df)
return df
except UnicodeDecodeError:
continue
except Exception as e:
logger.error(f"Error processing CSV with {encoding} encoding: {str(e)}")
# If all encodings fail, try one more time with errors='replace'
text_content = file_content.decode('utf-8', errors='replace')
df = pd.read_csv(io.StringIO(text_content))
df = self._preprocess_dataframe(df)
return df
except Exception as e:
logger.error(f"Failed to process CSV file {filename}: {str(e)}")
return None
def _process_excel(self, file_content: bytes, filename: str) -> Dict[str, pd.DataFrame]:
"""Process Excel file content into pandas DataFrames"""
result = {}
if file_content is None:
return result
try:
# Use BytesIO to create a file-like object
excel_file = io.BytesIO(file_content)
# Try to read with pandas
excel_data = pd.ExcelFile(excel_file)
# Process each sheet
for sheet_name in excel_data.sheet_names:
# Read each sheet via the already-parsed ExcelFile rather than re-reading the byte stream
df = pd.read_excel(excel_data, sheet_name=sheet_name)
# Basic preprocessing
df = self._preprocess_dataframe(df)
# Only include if there's actual data
if not df.empty:
result[sheet_name] = df
return result
except Exception as e:
logger.error(f"Failed to process Excel file {filename}: {str(e)}")
return result
def _process_json(self, file_content: bytes, filename: str) -> Optional[pd.DataFrame]:
"""Process JSON file content into a pandas DataFrame"""
if file_content is None:
return None
try:
# Decode and parse JSON (content may arrive as raw bytes or as an already-decoded string)
json_content = file_content.decode('utf-8') if isinstance(file_content, bytes) else file_content
data = json.loads(json_content)
# Handle different JSON structures
if isinstance(data, list):
# List of records
df = pd.DataFrame(data)
elif isinstance(data, dict):
# Try to find a suitable data structure in the dict
if any(isinstance(v, list) for v in data.values()):
# Find the first list value to use as data
for key, value in data.items():
if isinstance(value, list) and len(value) > 0:
df = pd.DataFrame(value)
break
else:
# No suitable list found
return None
else:
# Convert flat dict to a single-row DataFrame
df = pd.DataFrame([data])
else:
# Unsupported structure
return None
# Basic preprocessing
df = self._preprocess_dataframe(df)
return df
except Exception as e:
logger.error(f"Failed to process JSON file {filename}: {str(e)}")
return None
def _preprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
"""Perform basic preprocessing on a DataFrame"""
if df.empty:
return df
# Remove completely empty rows and columns
df = df.dropna(how='all')
df = df.dropna(axis=1, how='all')
# Try to convert string columns to numeric where appropriate
for col in df.columns:
# Skip if already numeric
if pd.api.types.is_numeric_dtype(df[col]):
continue
# Skip if mostly non-numeric strings
if df[col].dtype == 'object':
# Check if more than 80% of non-NA values could be numeric
non_na_values = df[col].dropna()
if len(non_na_values) == 0:
continue
# Try to convert to numeric and count successes
numeric_count = pd.to_numeric(non_na_values, errors='coerce').notna().sum()
if numeric_count / len(non_na_values) > 0.8:
# More than 80% can be converted to numeric, so convert the column
df[col] = pd.to_numeric(df[col], errors='coerce')
# Try to parse date columns
for col in df.columns:
# Skip if not object dtype
if df[col].dtype != 'object':
continue
# Check if column name suggests a date
if any(date_term in col.lower() for date_term in ['date', 'time', 'day', 'month', 'year']):
try:
# Parse as datetime into a temporary series so the original column is not clobbered
converted = pd.to_datetime(df[col], errors='coerce')
# Only keep the conversion if at least 80% of values parsed successfully;
# otherwise leave the original column untouched
if converted.notna().mean() >= 0.8:
df[col] = converted
except Exception:
pass
return df
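# Both conversions above use the same heuristic: a column is only retyped when
# roughly 80% or more of its non-null values convert cleanly, so mostly-textual
# columns keep their original object dtype.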
def _extract_document_text(self, message: Dict[str, Any]) -> str:
"""
Extract text from documents (fallback method).
Args:
message: Input message with documents
Returns:
Extracted text
"""
text_content = ""
for document in message.get("documents", []):
source = document.get("source", {})
name = source.get("name", "unnamed")
text_content += f"\n\n--- {name} ---\n"
for content in document.get("contents", []):
if content.get("type") == "text":
text_content += content.get("text", "")
return text_content
def _determine_analysis_type(self, task: str) -> str:
"""
Determine the type of analysis based on the task.
Enhanced to better handle text-based analysis.
Args:
task: The analysis task
Returns:
Analysis type
"""
task_lower = task.lower()
# Check for statistical analysis
if any(term in task_lower for term in ["statistics", "statistical", "mean", "median", "variance"]):
return "statistical"
# Check for trend analysis
elif any(term in task_lower for term in ["trend", "pattern", "time series", "historical"]):
return "trend"
# Check for comparative analysis
elif any(term in task_lower for term in ["compare", "comparison", "versus", "vs", "difference"]):
return "comparative"
# Check for predictive analysis
elif any(term in task_lower for term in ["predict", "forecast", "future", "projection"]):
return "predictive"
# Check for clustering or categorization
elif any(term in task_lower for term in ["cluster", "segment", "categorize", "classify"]):
return "clustering"
# Check for text analysis specific terms
elif any(term in task_lower for term in ["text", "sentiment", "topic", "semantic", "meaning", "interpretation"]):
return "textual"
# Check for summary requests
elif any(term in task_lower for term in ["summarize", "summary", "overview", "digest"]):
return "summary"
# Default to general analysis
else:
return "general"
def _extract_data_insights(self, data_frames: Dict[str, pd.DataFrame]) -> str:
"""
Extract basic insights from data frames.
Args:
data_frames: Dictionary of data frames
Returns:
Extracted insights as text
"""
insights = []
for name, df in data_frames.items():
if df.empty:
continue
insight = f"Dataset: {name}\n"
insight += f"Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
insight += f"Columns: {', '.join(df.columns.tolist())}\n"
# Basic statistics for numeric columns
numeric_cols = df.select_dtypes(include=['number']).columns
if len(numeric_cols) > 0:
insight += "Numeric column statistics:\n"
for col in numeric_cols[:5]: # Limit to first 5 columns
stats = df[col].describe()
insight += f" {col}: min={stats['min']:.2f}, max={stats['max']:.2f}, mean={stats['mean']:.2f}, median={df[col].median():.2f}\n"
if len(numeric_cols) > 5:
insight += f" ... and {len(numeric_cols) - 5} more numeric columns\n"
# Date range for datetime columns
date_cols = df.select_dtypes(include=['datetime']).columns
if len(date_cols) > 0:
insight += "Date range:\n"
for col in date_cols:
if df[col].notna().any():
min_date = df[col].min()
max_date = df[col].max()
insight += f" {col}: {min_date} to {max_date}\n"
# Categorical column value counts
cat_cols = df.select_dtypes(include=['object', 'category']).columns
if len(cat_cols) > 0:
insight += "Categorical columns:\n"
for col in cat_cols[:3]: # Limit to first 3 columns
# Get top 3 values
top_values = df[col].value_counts().head(3)
vals_str = ", ".join([f"{val} ({count})" for val, count in top_values.items()])
insight += f" {col}: {df[col].nunique()} unique values. Top values: {vals_str}\n"
if len(cat_cols) > 3:
insight += f" ... and {len(cat_cols) - 3} more categorical columns\n"
# Missing values
missing = df.isna().sum()
if missing.sum() > 0:
cols_with_missing = missing[missing > 0]
insight += "Missing values:\n"
for col, count in cols_with_missing.items():
pct = 100 * count / len(df)
insight += f" {col}: {count} missing values ({pct:.1f}%)\n"
insights.append(insight)
return "\n\n".join(insights)
def _generate_visualizations(self, data_frames: Dict[str, pd.DataFrame], analysis_type: str,
workflow_id: str, task: str) -> List[Dict[str, Any]]:
"""
Generate appropriate visualizations based on data and analysis type.
Args:
data_frames: Dictionary of DataFrames to visualize
analysis_type: Type of analysis being performed
workflow_id: Workflow ID
task: Original task description
Returns:
List of visualization document objects
"""
documents = []
for name, df in data_frames.items():
if df.empty or df.shape[0] < 2:
continue # Skip empty or single-row DataFrames
# Generate different visualizations based on the analysis type
if analysis_type == "statistical":
viz_docs = self._create_statistical_visualizations(df, name, workflow_id)
documents.extend(viz_docs)
elif analysis_type == "trend":
viz_docs = self._create_trend_visualizations(df, name, workflow_id)
documents.extend(viz_docs)
elif analysis_type == "comparative":
viz_docs = self._create_comparative_visualizations(df, name, workflow_id)
documents.extend(viz_docs)
elif analysis_type == "predictive":
viz_docs = self._create_predictive_visualizations(df, name, workflow_id)
documents.extend(viz_docs)
elif analysis_type == "clustering":
viz_docs = self._create_clustering_visualizations(df, name, workflow_id)
documents.extend(viz_docs)
else: # general analysis
viz_docs = self._create_general_visualizations(df, name, workflow_id)
documents.extend(viz_docs)
return documents
def _create_statistical_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
"""Create statistical visualizations for a DataFrame"""
documents = []
# 1. Distribution/Histogram plots for numeric columns
numeric_cols = df.select_dtypes(include=['number']).columns[:5] # Limit to first 5
if len(numeric_cols) > 0:
plt.figure(figsize=(12, 8))
for i, col in enumerate(numeric_cols, 1):
plt.subplot(len(numeric_cols), 1, i)
sns.histplot(df[col].dropna(), kde=True)
plt.title(f'Distribution of {col}')
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_stat_dist_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Statistical Distributions - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# 2. Box plots for numeric columns
if len(numeric_cols) > 0:
plt.figure(figsize=(12, 8))
sns.boxplot(data=df[numeric_cols])
plt.title(f'Box Plots of Numeric Variables in {name}')
plt.xticks(rotation=45)
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_stat_box_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Box Plots - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# 3. Correlation heatmap for numeric columns
if len(numeric_cols) >= 2:
plt.figure(figsize=(10, 8))
corr = df[numeric_cols].corr()
sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
plt.title(f'Correlation Heatmap - {name}')
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_stat_corr_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Correlation Heatmap - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
return documents
def _create_trend_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
"""Create trend visualizations for a DataFrame"""
documents = []
# Check for date/time columns
date_cols = df.select_dtypes(include=['datetime']).columns
# If we have date columns, create time series plots
if len(date_cols) > 0:
date_col = date_cols[0] # Use the first date column
# Find numeric columns to plot against the date
numeric_cols = df.select_dtypes(include=['number']).columns[:3] # Limit to first 3
if len(numeric_cols) > 0:
plt.figure(figsize=(12, 8))
for i, col in enumerate(numeric_cols, 1):
plt.subplot(len(numeric_cols), 1, i)
plt.plot(df[date_col], df[col])
plt.title(f'Trend of {col} over time')
plt.xticks(rotation=45)
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_trend_time_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Time Series Trends - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# If no date columns found, find another column that might represent sequence/order
else:
# Look for columns with sequential numbers
potential_sequence_cols = []
for col in df.select_dtypes(include=['number']).columns:
values = df[col].dropna().values
if len(values) >= 5:
# Check if values are mostly sequential
diffs = np.diff(sorted(values))
if np.all(diffs > 0) and np.std(diffs) / np.mean(diffs) < 0.5:
potential_sequence_cols.append(col)
# Use first potential sequence column or first numeric column
numeric_cols = df.select_dtypes(include=['number']).columns
if len(potential_sequence_cols) > 0 and len(numeric_cols) > 1:
sequence_col = potential_sequence_cols[0]
# Find other numeric columns to plot against the sequence
plot_cols = [col for col in numeric_cols if col != sequence_col][:2]
plt.figure(figsize=(12, 6))
for col in plot_cols:
plt.plot(df[sequence_col], df[col], marker='o', label=col)
plt.title(f'Trend by {sequence_col} - {name}')
plt.legend()
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_trend_seq_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Sequential Trends - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# Moving average visualization if we have enough data points
if len(df) > 10:
numeric_cols = df.select_dtypes(include=['number']).columns[:2] # Limit to first 2
if len(numeric_cols) > 0:
plt.figure(figsize=(12, 6))
for col in numeric_cols:
# Sort data if we have a date column
if len(date_cols) > 0:
sorted_df = df.sort_values(by=date_cols[0])
else:
sorted_df = df
# Calculate moving average (window size 3)
values = sorted_df[col].values
window_size = min(3, len(values) - 1)
if window_size > 0:
moving_avg = np.convolve(values, np.ones(window_size)/window_size, mode='valid')
# Plot original and moving average
plt.plot(values, label=f'{col} (Original)')
plt.plot(np.arange(window_size-1, len(values)), moving_avg, label=f'{col} (Moving Avg)')
plt.title(f'Moving Average Trends - {name}')
plt.legend()
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_trend_mavg_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Moving Average Trends - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
return documents
def _create_comparative_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
"""Create comparative visualizations for a DataFrame"""
documents = []
# 1. Look for categorical columns to use for grouping
cat_cols = df.select_dtypes(include=['object', 'category']).columns
if len(cat_cols) > 0:
# Use the first categorical column with reasonable number of unique values
groupby_col = None
for col in cat_cols:
unique_count = df[col].nunique()
if 2 <= unique_count <= 10: # Reasonable number of categories
groupby_col = col
break
if groupby_col:
# Find numeric columns to compare across groups
numeric_cols = df.select_dtypes(include=['number']).columns[:3] # Limit to first 3
if len(numeric_cols) > 0:
# 1. Bar chart comparing means
plt.figure(figsize=(12, 6))
mean_by_group = df.groupby(groupby_col)[numeric_cols].mean()
mean_by_group.plot(kind='bar')
plt.title(f'Mean Comparison by {groupby_col} - {name}')
plt.xticks(rotation=45)
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_comp_bar_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Mean Comparison by {groupby_col} - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# 2. Box plots for comparing distributions
plt.figure(figsize=(12, 8))
for i, col in enumerate(numeric_cols, 1):
plt.subplot(len(numeric_cols), 1, i)
sns.boxplot(x=groupby_col, y=col, data=df)
plt.title(f'Distribution of {col} by {groupby_col}')
plt.xticks(rotation=45)
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_comp_box_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Distribution Comparison - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# 3. Scatter plot comparing two numeric variables
numeric_cols = df.select_dtypes(include=['number']).columns
if len(numeric_cols) >= 2:
plt.figure(figsize=(10, 8))
# Use first two numeric columns
x_col, y_col = numeric_cols[0], numeric_cols[1]
scatter = plt.scatter(df[x_col], df[y_col])
plt.title(f'Comparison of {x_col} vs {y_col} - {name}')
plt.xlabel(x_col)
plt.ylabel(y_col)
# Add color if we have a categorical column
if len(cat_cols) > 0:
groupby_col = cat_cols[0]
if df[groupby_col].nunique() <= 10: # Reasonable number of categories
# Close the plain scatter figure before redrawing it with category-based colors
plt.close()
plt.figure(figsize=(10, 8))
scatter = plt.scatter(df[x_col], df[y_col], c=pd.factorize(df[groupby_col])[0], cmap='viridis')
plt.title(f'Comparison of {x_col} vs {y_col} by {groupby_col} - {name}')
plt.xlabel(x_col)
plt.ylabel(y_col)
legend1 = plt.legend(scatter.legend_elements()[0], df[groupby_col].unique(), title=groupby_col)
plt.gca().add_artist(legend1)
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_comp_scatter_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Variable Comparison - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
return documents
def _create_predictive_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
"""Create predictive visualizations for a DataFrame"""
documents = []
# Check for date/time columns for time series prediction
date_cols = df.select_dtypes(include=['datetime']).columns
if len(date_cols) > 0:
date_col = date_cols[0] # Use the first date column
# Sort by date
df_sorted = df.sort_values(by=date_col)
# Find numeric columns to predict
numeric_cols = df.select_dtypes(include=['number']).columns[:2] # Limit to first 2
if len(numeric_cols) > 0:
plt.figure(figsize=(12, 8))
for i, col in enumerate(numeric_cols, 1):
plt.subplot(len(numeric_cols), 1, i)
# Get values and dates
values = df_sorted[col].values
dates = pd.to_datetime(df_sorted[date_col]).tolist()  # pandas Timestamps so the timedelta math below works
# Need minimum number of points for meaningful prediction
if len(values) >= 5:
# Use basic linear regression for prediction
# Convert dates to numeric values for regression
date_nums = np.array([(d - dates[0]).total_seconds() for d in dates])
date_nums = date_nums / np.max(date_nums) # Normalize
# Remove NaNs
mask = ~np.isnan(values)
if np.sum(mask) >= 3: # Need at least 3 points
x = date_nums[mask].reshape(-1, 1)
y = values[mask]
# Fit linear regression
from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(x, y)
# Predict on original range
y_pred = model.predict(x)
# Extend for prediction
x_extended = np.linspace(0, 1.2, 100).reshape(-1, 1)
y_extended = model.predict(x_extended)
# Convert x_extended back to dates for plotting
max_seconds = np.max([(d - dates[0]).total_seconds() for d in dates])
future_seconds = x_extended.flatten() * max_seconds
future_dates = [dates[0] + pd.Timedelta(seconds=s) for s in future_seconds]
# Plot
plt.plot(dates, values, 'o-', label='Actual')
plt.plot(future_dates, y_extended, '--', label='Predicted')
plt.axvline(x=dates[-1], color='r', linestyle=':', label='Current')
plt.title(f'Prediction for {col}')
plt.xticks(rotation=45)
plt.legend()
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_pred_time_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Time Series Prediction - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# Regression prediction (feature vs target)
numeric_cols = df.select_dtypes(include=['number']).columns
if len(numeric_cols) >= 2:
plt.figure(figsize=(10, 8))
# Use first two numeric columns as feature and target
x_col, y_col = numeric_cols[0], numeric_cols[1]
# Remove NaNs
df_clean = df[[x_col, y_col]].dropna()
if len(df_clean) >= 5: # Need minimum points for regression
x = df_clean[x_col].values.reshape(-1, 1)
y = df_clean[y_col].values
# Fit linear regression
from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(x, y)
# Generate predictions
x_range = np.linspace(df_clean[x_col].min(), df_clean[x_col].max() * 1.1, 100).reshape(-1, 1)
y_pred = model.predict(x_range)
# Plot
plt.scatter(df_clean[x_col], df_clean[y_col], label='Data Points')
plt.plot(x_range, y_pred, 'r--', label=f'Predicted {y_col}')
plt.title(f'Regression Prediction: {y_col} based on {x_col} - {name}')
plt.xlabel(x_col)
plt.ylabel(y_col)
plt.legend()
# Add regression equation
slope = model.coef_[0]
intercept = model.intercept_
plt.text(0.05, 0.95, f'{y_col} = {slope:.2f} * {x_col} + {intercept:.2f}',
transform=plt.gca().transAxes, fontsize=10,
verticalalignment='top')
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_pred_reg_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Regression Prediction - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
return documents
def _create_clustering_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
"""Create clustering visualizations for a DataFrame"""
documents = []
# Need numeric columns for clustering
numeric_cols = df.select_dtypes(include=['number']).columns
if len(numeric_cols) >= 2:
# Select two numeric columns for 2D visualization
cols = numeric_cols[:2]
# Remove NaNs
df_clean = df[cols].dropna().copy()  # copy so adding the Cluster column below doesn't modify a view
if len(df_clean) >= 5: # Need minimum points for clustering
# Normalize data
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
data_scaled = scaler.fit_transform(df_clean)
# Apply K-means clustering
from sklearn.cluster import KMeans
# Determine number of clusters (2-5 based on data size)
n_clusters = min(max(2, len(df_clean) // 10), 5)
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
clusters = kmeans.fit_predict(data_scaled)
# Add cluster labels to DataFrame
df_clean['Cluster'] = clusters
# Create scatter plot with clusters
plt.figure(figsize=(10, 8))
# Plot clusters
scatter = plt.scatter(df_clean[cols[0]], df_clean[cols[1]], c=df_clean['Cluster'], cmap='viridis')
# Plot centroids
centroids = scaler.inverse_transform(kmeans.cluster_centers_)
plt.scatter(centroids[:, 0], centroids[:, 1], marker='X', s=200, c='red', label='Centroids')
plt.title(f'K-means Clustering ({n_clusters} clusters) - {name}')
plt.xlabel(cols[0])
plt.ylabel(cols[1])
# Keep both the cluster legend and the centroid label; a second bare plt.legend() would replace the first
cluster_legend = plt.legend(*scatter.legend_elements(), title="Clusters", loc="upper left")
plt.gca().add_artist(cluster_legend)
plt.legend(loc="upper right")
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_clust_kmeans_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"K-means Clustering - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# If we have more than 2 numeric columns, also create a PCA visualization
if len(numeric_cols) > 2:
# Local imports, matching the inline-import style used elsewhere in this method
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
# Select more columns for PCA
pca_cols = numeric_cols[:min(len(numeric_cols), 5)]
# Remove NaNs
df_pca = df[pca_cols].dropna()
if len(df_pca) >= 5:
# Normalize data
pca_data = StandardScaler().fit_transform(df_pca)
# Apply PCA to reduce to 2 dimensions
pca = PCA(n_components=2)
principal_components = pca.fit_transform(pca_data)
# Create DataFrame with principal components
pca_df = pd.DataFrame(data=principal_components, columns=['PC1', 'PC2'])
# Apply clustering to PCA results
clusters = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(pca_df)
pca_df['Cluster'] = clusters
# Create scatter plot
plt.figure(figsize=(10, 8))
scatter = plt.scatter(pca_df['PC1'], pca_df['PC2'], c=pca_df['Cluster'], cmap='viridis')
plt.title(f'PCA Clustering ({n_clusters} clusters) - {name}')
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
plt.legend(*scatter.legend_elements(), title="Clusters")
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_clust_pca_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"PCA Clustering - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
return documents
def _create_general_visualizations(self, df: pd.DataFrame, name: str, workflow_id: str) -> List[Dict[str, Any]]:
"""Create general purpose visualizations for a DataFrame"""
documents = []
# 1. Data overview: numeric summary
numeric_cols = df.select_dtypes(include=['number']).columns
if len(numeric_cols) > 0:
# Create a bar chart of means for numeric columns
plt.figure(figsize=(12, 6))
means = df[numeric_cols].mean().sort_values()
means.plot(kind='bar')
plt.title(f'Mean Values of Numeric Variables - {name}')
plt.xticks(rotation=45)
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_gen_means_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Numeric Variables Summary - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# 2. Categorical data overview
cat_cols = df.select_dtypes(include=['object', 'category']).columns
if len(cat_cols) > 0:
# Select the first categorical column with reasonable cardinality
for col in cat_cols:
if df[col].nunique() <= 10: # Reasonable number of categories
plt.figure(figsize=(10, 6))
value_counts = df[col].value_counts().sort_values(ascending=False)
value_counts.plot(kind='bar')
plt.title(f'Distribution of {col} - {name}')
plt.xticks(rotation=45)
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_gen_cat_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Categorical Distribution - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
break # Only use the first suitable column
# 3. Correlation matrix if we have multiple numeric columns
if len(numeric_cols) >= 2:
plt.figure(figsize=(10, 8))
corr = df[numeric_cols].corr()
sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
plt.title(f'Correlation Matrix - {name}')
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_gen_corr_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Correlation Matrix - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
# 4. If we have date columns, show time-based visualization
date_cols = df.select_dtypes(include=['datetime']).columns
if len(date_cols) > 0 and len(numeric_cols) > 0:
date_col = date_cols[0] # Use the first date column
num_col = numeric_cols[0] # Use the first numeric column
plt.figure(figsize=(12, 6))
plt.plot(df[date_col], df[num_col], marker='o')
plt.title(f'{num_col} over Time - {name}')
plt.xticks(rotation=45)
plt.tight_layout()
# Save figure
img_data = self._get_figure_as_base64()
plt.close()
# Create document
doc_id = f"viz_gen_time_{uuid.uuid4()}"
doc = {
"id": doc_id,
"source": {
"type": "generated",
"id": doc_id,
"name": f"Time Series Overview - {name}",
"content_type": "image/png",
"size": len(img_data)
},
"contents": [{
"type": "image",
"data": img_data,
"format": "base64"
}]
}
documents.append(doc)
return documents
def _get_figure_as_base64(self) -> str:
"""Convert current matplotlib figure to base64 string"""
buffer = io.BytesIO()
plt.savefig(buffer, format='png', dpi=self.chart_dpi)
buffer.seek(0)
image_png = buffer.getvalue()
buffer.close()
# Convert to base64
image_base64 = base64.b64encode(image_png).decode('utf-8')
return image_base64
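# Consumers can recover the PNG bytes with base64.b64decode(image_base64),
# e.g. to write the visualization to disk or embed it in a report.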
"""
Enhanced _generate_analysis method to better handle text-only analysis.
"""
async def _generate_analysis(self, prompt: str, analysis_type: str) -> str:
"""
Generate analysis based on prompt and analysis type.
Enhanced to handle text-only analysis.
Args:
prompt: The analysis prompt
analysis_type: Type of analysis
Returns:
Generated analysis
"""
if not self.ai_service:
logger.warning("AI service not available for analysis generation")
return f"## Data Analysis ({analysis_type})\n\nUnable to generate analysis: AI service not available."
# Create specialized prompt based on analysis type
system_prompt = self._get_analysis_system_prompt(analysis_type)
# Determine if this is a data-based or text-based analysis
is_data_analysis = "DATA INSIGHTS" in prompt
# Enhance the prompt with analysis-specific instructions
if is_data_analysis:
enhanced_prompt = f"""
Generate a detailed {analysis_type} analysis based on the following data:
{prompt}
Your analysis should include:
1. A summary of the data
2. Key findings and insights
3. Supporting evidence and calculations
4. Clear conclusions
5. Recommendations where appropriate
Format the analysis in Markdown with proper headings, lists, and tables.
"""
else:
# Text-based analysis instructions
enhanced_prompt = f"""
Generate a detailed {analysis_type} analysis of the following text content:
{prompt}
Your analysis should include:
1. A summary of the main themes and topics
2. Key insights and observations
3. Analysis of structure, patterns, and relationships
4. Clear conclusions and interpretations
5. Recommendations or implications where appropriate
Format the analysis in Markdown with proper headings, lists, and tables.
"""
try:
content = await self.ai_service.call_api([
{"role": "system", "content": system_prompt},
{"role": "user", "content": enhanced_prompt}
])
# Ensure there's a title at the top
if not content.strip().startswith("# "):
content = f"# {analysis_type.capitalize()} Analysis\n\n{content}"
return content
except Exception as e:
return f"# {analysis_type.capitalize()} Analysis\n\nError generating analysis: {str(e)}"
def _get_analysis_system_prompt(self, analysis_type: str) -> str:
"""
Get specialized system prompt for specific analysis type.
Enhanced with text analysis capabilities.
Args:
analysis_type: Type of analysis
Returns:
System prompt
"""
base_prompt = self._get_system_prompt()
# Add analysis-specific instructions
if analysis_type == "statistical":
return f"{base_prompt}\n\nFocus on statistical measures including mean, median, mode, variance, and distribution. Identify outliers and unusual data points. Present key statistics in tables where appropriate."
elif analysis_type == "trend":
return f"{base_prompt}\n\nFocus on identifying trends over time, seasonality, and patterns in the data. Look for long-term movements, cyclical patterns, and turning points. Consider rate of change and growth rates."
elif analysis_type == "comparative":
return f"{base_prompt}\n\nFocus on comparing different groups, categories, or time periods. Highlight similarities and differences. Use comparative metrics and relative measures rather than just absolute values."
elif analysis_type == "predictive":
return f"{base_prompt}\n\nFocus on extrapolating trends and patterns to make predictions about future values. Discuss confidence levels and potential factors that could influence outcomes. Be clear about assumptions."
elif analysis_type == "clustering":
return f"{base_prompt}\n\nFocus on identifying natural groupings or segments within the data. Describe the characteristics of each cluster and what distinguishes them. Consider similarities within groups and differences between groups."
elif analysis_type == "textual":
return f"{base_prompt}\n\nFocus on analyzing the text content provided. Identify key themes, topics, and concepts. Analyze sentiment, tone, and perspective. Extract important relationships, arguments, or logical structures. Provide insights into the meaning and implications of the text."
elif analysis_type == "summary":
return f"{base_prompt}\n\nFocus on providing a concise overview of the provided content. Identify the main points, key arguments, and essential information. Distill complex information into clear, digestible insights. Maintain objectivity while highlighting the most important elements."
else:
return base_prompt
def _get_system_prompt(self) -> str:
"""
Get specialized system prompt for analyst agent.
Enhanced to handle text analysis better.
Returns:
System prompt
"""
return f"""
You are {self.name}, a specialized {self.type} agent focused on data and text analysis.
{self.description}
When analyzing data:
1. First, identify the data structure and key variables
2. Look for patterns, trends, and outliers
3. Provide statistical insights and evidence-based conclusions
4. Highlight any important findings clearly
5. Suggest visualizations that would help understand the data
When analyzing text content:
1. Identify key themes, concepts, and topics
2. Extract important patterns and relationships
3. Provide insights into the meaning and implications of the text
4. Identify sentiment, tone, and perspective where relevant
5. Organize findings in a logical, structured way
For CSV data, interpret tables correctly and perform calculations accurately.
For textual data, extract key metrics, themes and relationships.
Respond in a clear, analytical style, and format your findings in a structured report.
"""
def send_analysis_result(self, analysis_content: str, sender_id: str, receiver_id: str,
task_id: str, analysis_data: Dict[str, Any] = None,
context_id: str = None) -> AgentMessage:
"""
Send analysis results using the protocol.
Args:
analysis_content: Analysis content
sender_id: Sender ID
receiver_id: Receiver ID
task_id: Task ID
analysis_data: Additional analysis data
context_id: Context ID
Returns:
Protocol message
"""
return self.protocol.create_result_message(
result_content=analysis_content,
sender_id=sender_id,
receiver_id=receiver_id,
task_id=task_id,
output_data=analysis_data,
result_format=self.result_format,
context_id=context_id
)
def send_error_message(self, error_description: str, sender_id: str, receiver_id: str = None,
error_details: Dict[str, Any] = None, context_id: str = None) -> AgentMessage:
"""
Send error message using the protocol.
Args:
error_description: Error description
sender_id: Sender ID
receiver_id: Receiver ID
error_details: Error details
context_id: Context ID
Returns:
Protocol message
"""
return self.protocol.create_error_message(
error_description=error_description,
sender_id=sender_id,
receiver_id=receiver_id,
error_type="analysis_error",
error_details=error_details,
context_id=context_id
)
def send_document_request_message(self, document_description: str, sender_id: str, receiver_id: str,
filters: Dict[str, Any] = None, context_id: str = None) -> AgentMessage:
"""
Send document request using the protocol.
Args:
document_description: Document description
sender_id: Sender ID
receiver_id: Receiver ID
filters: Document filters
context_id: Context ID
Returns:
Protocol message
"""
return self.protocol.create_document_request_message(
document_description=document_description,
sender_id=sender_id,
receiver_id=receiver_id,
filters=filters,
context_id=context_id
)
# Singleton instance
_analyst_agent = None
def get_analyst_agent():
"""Returns a singleton instance of the data analyst agent"""
global _analyst_agent
if _analyst_agent is None:
_analyst_agent = AnalystAgent()
return _analyst_agent
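# Example usage (illustrative sketch; assumes an asyncio entry point and that the
# surrounding service has configured `ai_service` and, optionally, a document handler):
#
#   import asyncio
#   agent = get_analyst_agent()
#   message = {"content": "Identify trends in the attached sales data", "documents": []}
#   response = asyncio.run(agent.process_message(message, context={"workflow_id": "demo"}))
#   print(response["content"])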