"""
|
|
AI processing method module.
|
|
Handles direct AI calls for any type of task.
|
|
"""
|
|
|
|
import time
|
|
import logging
|
|
from typing import Dict, Any, List, Optional
|
|
from datetime import datetime, UTC
|
|
|
|
from modules.workflows.methods.methodBase import MethodBase, action
|
|
from modules.datamodels.datamodelChat import ActionResult
|
|
from modules.datamodels.datamodelAi import AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum, ModelCapabilitiesEnum
|
|
from modules.datamodels.datamodelChat import ChatDocument
|
|
from modules.aicore.aicorePluginTavily import WebResearchRequest
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class MethodAi(MethodBase):
    """AI processing methods."""

    def __init__(self, services):
        super().__init__(services)
        self.name = "ai"
        self.description = "AI processing methods"

    def _format_timestamp_for_filename(self) -> str:
        """Format current timestamp as YYYYMMDD-hhmmss for filenames."""
        return datetime.now(UTC).strftime("%Y%m%d-%H%M%S")

    @action
    async def process(self, parameters: Dict[str, Any]) -> ActionResult:
"""
|
|
GENERAL:
|
|
- Purpose: Process a user prompt with optional unlimited input documents to produce one or many output documents of the SAME format.
|
|
- Input requirements: aiPrompt (required); optional documentList.
|
|
- Output format: Exactly one file format to select. For multiple output file formats to do different calls.
|
|
|
|
Parameters:
|
|
- aiPrompt (str, required): Instruction for the AI.
|
|
- documentList (list, optional): Document reference(s) for context.
|
|
- resultType (str, optional): Output file extension - only one extension allowed (e.g. txt, json, md, csv, xml, html, pdf, docx, xlsx, png, ...). Default: txt.
|
|
- processingMode (str, optional): basic | advanced | detailed. Default: basic.
|
|
- includeMetadata (bool, optional): Include metadata when available. Default: True.
|
|
- operationType (str, optional): general | plan | analyse | generate | webResearch | imageAnalyse | imageGenerate. Default: general.
|
|
- priority (str, optional): speed | quality | cost | balanced. Default: balanced.
|
|
- maxCost (float, optional): Cost limit.
|
|
- maxProcessingTime (int, optional): Time limit in seconds.
|
|
- operationTypes (list, optional): Capability tags (e.g., text, chat, reasoning, analysis, image, vision, web, search).
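
        Example (illustrative payload; values are placeholders, not from the source):
            {
                "aiPrompt": "Summarize the attached reports",
                "documentList": ["doc-1", "doc-2"],
                "resultType": "md",
                "processingMode": "advanced",
                "priority": "quality"
            }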
        """
        try:
            # Init progress logger
            operationId = f"ai_process_{self.services.currentWorkflow.id}_{int(time.time())}"

            # Start progress tracking
            self.services.workflow.progressLogStart(
                operationId,
                "Generate",
                "AI Processing",
                f"Format: {parameters.get('resultType', 'txt')}"
            )

            # Debug logging to see what parameters are received
            logger.info(f"MethodAi.process received parameters: {parameters}")
            logger.info(f"Parameters type: {type(parameters)}")
            logger.info(f"Parameters keys: {list(parameters.keys()) if isinstance(parameters, dict) else 'Not a dict'}")

            aiPrompt = parameters.get("aiPrompt")
            logger.info(f"aiPrompt extracted: '{aiPrompt}' (type: {type(aiPrompt)})")

            # Update progress - preparing parameters
            self.services.workflow.progressLogUpdate(operationId, 0.2, "Preparing parameters")

            documentList = parameters.get("documentList", [])
            if isinstance(documentList, str):
                documentList = [documentList]
            resultType = parameters.get("resultType", "txt")
            processingModeStr = parameters.get("processingMode", "basic")
            includeMetadata = parameters.get("includeMetadata", True)
            operationTypeStr = parameters.get("operationType", "general")
            priorityStr = parameters.get("priority", "balanced")
            maxCost = parameters.get("maxCost")
            maxProcessingTime = parameters.get("maxProcessingTime")
            requiredTags = parameters.get("requiredTags", [])

            # Map string parameters to enums
            operationTypeMapping = {
                "general": OperationTypeEnum.GENERAL,
                "plan": OperationTypeEnum.PLAN,
                "analyse": OperationTypeEnum.ANALYSE,
                "generate": OperationTypeEnum.GENERATE,
                "webResearch": OperationTypeEnum.WEB_RESEARCH,
                "imageAnalyse": OperationTypeEnum.IMAGE_ANALYSE,
                "imageGenerate": OperationTypeEnum.IMAGE_GENERATE
            }
            operationType = operationTypeMapping.get(operationTypeStr, OperationTypeEnum.GENERAL)

            priorityMapping = {
                "speed": PriorityEnum.SPEED,
                "quality": PriorityEnum.QUALITY,
                "cost": PriorityEnum.COST,
                "balanced": PriorityEnum.BALANCED
            }
            priority = priorityMapping.get(priorityStr, PriorityEnum.BALANCED)

            processingModeMapping = {
                "basic": ProcessingModeEnum.BASIC,
                "advanced": ProcessingModeEnum.ADVANCED,
                "detailed": ProcessingModeEnum.DETAILED
            }
            processingMode = processingModeMapping.get(processingModeStr, ProcessingModeEnum.BASIC)

            # Map requiredTags from strings to ModelCapabilitiesEnum
            if requiredTags and isinstance(requiredTags, list):
                tagMapping = {
                    "text": ModelCapabilitiesEnum.TEXT_GENERATION,
                    "chat": ModelCapabilitiesEnum.CHAT,
                    "reasoning": ModelCapabilitiesEnum.REASONING,
                    "analysis": ModelCapabilitiesEnum.ANALYSIS,
                    "image": ModelCapabilitiesEnum.VISION,
                    "vision": ModelCapabilitiesEnum.VISION,
                    "web": ModelCapabilitiesEnum.WEB_SEARCH,
                    "search": ModelCapabilitiesEnum.WEB_SEARCH
                }
                # Unknown tag strings pass through unchanged for the AI service to handle
                requiredTags = [tagMapping.get(tag, tag) for tag in requiredTags if isinstance(tag, str)]

            if not aiPrompt:
                logger.error(f"aiPrompt is missing or empty. Parameters: {parameters}")
                return ActionResult.isFailure(
                    error="AI prompt is required"
                )

            # Determine output extension and default MIME type without duplicating service logic
            normalized_result_type = (str(resultType).strip().lstrip('.').lower() or "txt")
            output_extension = f".{normalized_result_type}"
            output_mime_type = "application/octet-stream"  # Prefer service-provided mimeType when available
            logger.info(f"Using result type: {resultType} -> {output_extension}")

            # Update progress - preparing documents
            self.services.workflow.progressLogUpdate(operationId, 0.3, "Preparing documents")

            # Get ChatDocuments for AI service - let AI service handle all document processing
            chatDocuments = []
            if documentList:
                chatDocuments = self.services.workflow.getChatDocumentsFromDocumentList(documentList)
                if chatDocuments:
                    logger.info(f"Prepared {len(chatDocuments)} documents for AI processing")

            # Update progress - preparing AI call
            self.services.workflow.progressLogUpdate(operationId, 0.4, "Preparing AI call")

            # Build options and delegate document handling to AI/Extraction/Generation services
            output_format = output_extension.replace('.', '') or 'txt'
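            # These option values are fixed policy in this action: the prompt is
            # compressed except in DETAILED mode, context is always compressed,
            # and documents are processed individually.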
            options = AiCallOptions(
                operationType=operationType,
                priority=priority,
                compressPrompt=processingMode != ProcessingModeEnum.DETAILED,
                compressContext=True,
                processDocumentsIndividually=True,
                processingMode=processingMode,
                resultFormat=output_format,
                maxCost=maxCost,
                maxProcessingTime=maxProcessingTime,
                capabilities=requiredTags if requiredTags else None
            )

            # Update progress - calling AI
            self.services.workflow.progressLogUpdate(operationId, 0.6, "Calling AI")

            result = await self.services.ai.callAiDocuments(
                prompt=aiPrompt,  # Use original prompt, let unified generation handle prompt building
                documents=chatDocuments if chatDocuments else None,
                options=options,
                outputFormat=output_format
            )

            # Update progress - processing result
            self.services.workflow.progressLogUpdate(operationId, 0.8, "Processing result")

            from modules.datamodels.datamodelChat import ActionDocument

            if isinstance(result, dict) and isinstance(result.get("documents"), list):
                action_documents = []
                for d in result["documents"]:
                    action_documents.append(ActionDocument(
                        documentName=d.get("documentName"),
                        documentData=d.get("documentData"),
                        mimeType=d.get("mimeType") or output_mime_type
                    ))

                # Complete progress tracking
                self.services.workflow.progressLogFinish(operationId, True)

                return ActionResult.isSuccess(documents=action_documents)

            extension = output_extension.lstrip('.')
            meaningful_name = self._generateMeaningfulFileName(
                base_name="ai",
                extension=extension,
                action_name="result"
            )
            action_document = ActionDocument(
                documentName=meaningful_name,
                documentData=result,
                mimeType=output_mime_type
            )

            # Complete progress tracking
            self.services.workflow.progressLogFinish(operationId, True)

            return ActionResult.isSuccess(documents=[action_document])

        except Exception as e:
            logger.error(f"Error in AI processing: {str(e)}")

            # Complete progress tracking with failure
            try:
                self.services.workflow.progressLogFinish(operationId, False)
            except Exception:
                pass  # Don't fail on progress logging errors

            return ActionResult.isFailure(
                error=str(e)
            )

    @action
    async def webResearch(self, parameters: Dict[str, Any]) -> ActionResult:
"""
|
|
GENERAL:
|
|
- Purpose: Web research and information gathering with basic analysis and sources.
|
|
- Input requirements: user_prompt (required); optional urls, max_results, max_pages, search_depth, extract_depth, pages_search_depth, country, time_range, topic, language.
|
|
- Output format: JSON with results and sources.
|
|
|
|
Parameters:
|
|
- user_prompt (str, required): Research question or topic.
|
|
- urls (list, optional): Specific URLs to crawl.
|
|
- max_results (int, optional): Max search results. Default: 5.
|
|
- max_pages (int, optional): Max pages to crawl per site. Default: 5.
|
|
- search_depth (str, optional): basic | advanced. Default: basic.
|
|
- extract_depth (str, optional): basic | advanced. Default: advanced.
|
|
- pages_search_depth (int, optional): Crawl depth level. Default: 2.
|
|
- country (str, optional): Full English country name (ISO-3166; map codes via pycountry/i18n-iso-countries).
|
|
- time_range (str, optional): d | w | m | y.
|
|
- topic (str, optional): general | news | academic.
|
|
- language (str, optional): Language code (e.g., de, en, fr).
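
        Example (illustrative payload; values are placeholders, not from the source):
            {
                "user_prompt": "Current state of EU AI regulation",
                "max_results": 5,
                "search_depth": "advanced",
                "time_range": "m",
                "language": "en"
            }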
        """
        try:
            user_prompt = parameters.get("user_prompt")
            urls = parameters.get("urls")
            max_results = parameters.get("max_results", 5)
            max_pages = parameters.get("max_pages", 5)
            search_depth = parameters.get("search_depth", "basic")
            extract_depth = parameters.get("extract_depth", "advanced")
            # NOTE: pages_search_depth is accepted here but not currently forwarded to WebResearchRequest
            pages_search_depth = parameters.get("pages_search_depth", 2)
            country = parameters.get("country")
            time_range = parameters.get("time_range")
            topic = parameters.get("topic")
            language = parameters.get("language")

            if not user_prompt:
                return ActionResult.isFailure(
                    error="Search query is required"
                )

            # Build WebResearchRequest (simplified dataclass)
            request = WebResearchRequest(
                user_prompt=user_prompt,
                urls=urls,
                max_results=max_results,
                max_pages=max_pages,
                search_depth=search_depth,
                extract_depth=extract_depth,
                country=country,
                time_range=time_range,
                topic=topic,
                language=language
            )

            # Call web research service
            logger.info(f"Performing comprehensive web research for: {user_prompt}")
            logger.info(f"Max results: {max_results}, Max pages: {max_pages}")
            if urls:
                logger.info(f"Using provided URLs: {len(urls)}")

            result = await self.services.ai.webResearch(request)

            if not result.success:
                return ActionResult.isFailure(error=result.error)

            # Convert WebResearchResult to ActionResult format
            documents = []
            for doc in result.documents:
                documents.append({
                    "documentName": doc.documentName,
                    "documentData": {
                        "user_prompt": doc.documentData.user_prompt,
                        "websites_analyzed": doc.documentData.websites_analyzed,
                        "additional_links_found": doc.documentData.additional_links_found,
                        "analysis_result": doc.documentData.analysis_result,
                        "sources": [{"title": s.title, "url": str(s.url)} for s in doc.documentData.sources],
                        "additional_links": doc.documentData.additional_links,
                        "debug_info": doc.documentData.debug_info
                    },
                    "mimeType": doc.mimeType
                })

            # Return result in the standard ActionResult format
            return ActionResult.isSuccess(
                documents=documents
            )

        except Exception as e:
            logger.error(f"Error in web research: {str(e)}")
            return ActionResult.isFailure(
                error=str(e)
            )

    def _mergeDataChunks(self, chunks: List[str], resultType: str, mimeType: str) -> str:
        """Intelligently merge data chunks using strategies based on content type."""
        try:
            if resultType == "json":
                return self._mergeJsonChunks(chunks)
            elif resultType in ["csv", "table"]:
                return self._mergeTableChunks(chunks)
            elif resultType in ["txt", "md", "text"]:
                return self._mergeTextChunks(chunks)
            else:
                # Default: simple concatenation
                return "\n".join(str(chunk) for chunk in chunks)
        except Exception as e:
            logger.warning(f"Failed to merge chunks intelligently: {str(e)}, using simple concatenation")
            return "\n".join(str(chunk) for chunk in chunks)

    def _mergeJsonChunks(self, chunks: List[str]) -> str:
        """Merge JSON chunks intelligently."""
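        # Illustrative behavior (assuming well-formed inputs): merging the chunks
        # ['[1, 2]', '{"a": 1}'] yields the JSON array [1, 2, {"a": 1}] - list
        # chunks are extended into the result, later dicts are appended to it.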
        import json

        merged_data = []
        for i, chunk in enumerate(chunks):
            try:
                if isinstance(chunk, str):
                    chunk_data = json.loads(chunk)
                else:
                    chunk_data = chunk

                if isinstance(chunk_data, list):
                    if isinstance(merged_data, dict):
                        merged_data = [merged_data]  # Mixed dict/list chunks: demote to a list
                    merged_data.extend(chunk_data)
                elif isinstance(chunk_data, dict):
                    # For objects, merge by combining keys
                    if isinstance(merged_data, dict):
                        merged_data.update(chunk_data)
                    elif not merged_data:
                        merged_data = chunk_data
                    else:
                        merged_data.append(chunk_data)
                else:
                    if isinstance(merged_data, dict):
                        merged_data = [merged_data]  # Mixed dict/scalar chunks: demote to a list
                    merged_data.append(chunk_data)
            except Exception as e:
                logger.warning(f"Failed to parse chunk {i}: {str(e)}")
                # Add as string if JSON parsing fails
                if isinstance(merged_data, dict):
                    merged_data = [merged_data]
                merged_data.append(str(chunk))

        return json.dumps(merged_data, indent=2)

    def _mergeTableChunks(self, chunks: List[str]) -> str:
        """Merge table chunks (CSV) intelligently."""
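        # Illustrative behavior: merging ["a,b\n1,2", "a,b\n3,4"] keeps a single
        # header row and yields the rows [["a","b"], ["1","2"], ["3","4"]].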
        import csv
        import io

        merged_rows = []
        headers = None

        for i, chunk in enumerate(chunks):
            try:
                # Parse CSV chunk
                reader = csv.reader(io.StringIO(str(chunk)))
                rows = list(reader)

                if not rows:
                    continue

                # First chunk: capture headers
                if i == 0:
                    headers = rows[0] if rows else []
                    merged_rows.extend(rows)
                else:
                    # Subsequent chunks: skip header if it matches
                    if rows and rows[0] == headers:
                        merged_rows.extend(rows[1:])  # Skip duplicate header
                    else:
                        merged_rows.extend(rows)

            except Exception as e:
                logger.warning(f"Failed to parse table chunk {i}: {str(e)}")
                # Add as raw text if CSV parsing fails
                merged_rows.append([f"Raw chunk {i}: {str(chunk)[:100]}..."])

        # Convert back to CSV
        output = io.StringIO()
        writer = csv.writer(output)
        writer.writerows(merged_rows)
        return output.getvalue()

    def _mergeTextChunks(self, chunks: List[str]) -> str:
        """Merge text chunks intelligently."""
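        # Illustrative behavior: ["First part ", "", " Second part"] merges to
        # "First part\n\nSecond part" - empty chunks are dropped.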
        # Simple concatenation with proper spacing
        merged = []
        for chunk in chunks:
            chunk_str = str(chunk).strip()
            if chunk_str:
                merged.append(chunk_str)

        return "\n\n".join(merged)  # Double newline between chunks for readability