149 lines
7 KiB
Python
149 lines
7 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
|
|
import logging
|
|
import time
|
|
from typing import Dict, Any
|
|
from modules.datamodels.datamodelChat import ActionResult, ActionDocument
|
|
from modules.datamodels.datamodelDocref import (
|
|
DocumentReferenceList,
|
|
coerceDocumentReferenceList,
|
|
)
|
|
from modules.datamodels.datamodelExtraction import ExtractionOptions, MergeStrategy
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
async def extractContent(self, parameters: Dict[str, Any]) -> ActionResult:
    """Extract content from the documents referenced by ``parameters["documentList"]``.

    Recognised ``parameters`` keys:
      * ``documentList`` (required): anything ``coerceDocumentReferenceList``
        accepts -- per the error message below, a ``DocumentReferenceList``, a
        list of strings/dicts, or a wrapper dict like ``{"documents": [...]}``.
      * ``extractionOptions`` (optional): a dict or an ``ExtractionOptions``
        instance; missing or invalid values fall back to built-in defaults.
      * ``parentOperationId`` (optional): parent id for hierarchical progress logging.

    ``self`` is expected to expose ``services.workflow``, ``services.chat`` and
    ``services.extraction`` (all are accessed below).

    Returns:
        ``ActionResult.isSuccess`` with one JSON ``ActionDocument`` per extraction
        result, or ``ActionResult.isFailure`` carrying an error message. Any
        exception raised along the way is logged and converted into a failure
        result instead of propagating.
    """
    operationId = None
    # True once progressLogStart has been called, so the error path only closes
    # a progress operation that was actually opened.
    progressStarted = False
    try:
        workflowId = self.services.workflow.id if self.services.workflow else f"no-workflow-{int(time.time())}"
        operationId = f"context_extract_{workflowId}_{int(time.time())}"

        documentListParam = parameters.get("documentList")
        if not documentListParam:
            return ActionResult.isFailure(error="documentList is required")

        documentList = coerceDocumentReferenceList(documentListParam)
        if not documentList.references:
            return ActionResult.isFailure(
                error=f"documentList could not be parsed (type={type(documentListParam).__name__}); "
                f"expected DocumentReferenceList, list of strings/dicts, or "
                f"a wrapper dict like {{'documents': [...]}}"
            )

        # Start progress tracking
        parentOperationId = parameters.get('parentOperationId')
        self.services.chat.progressLogStart(
            operationId,
            "Extracting content from documents",
            "Content Extraction",
            f"Documents: {len(documentList.references)}",
            parentOperationId=parentOperationId,
        )
        progressStarted = True

        # Resolve the references into ChatDocuments
        self.services.chat.progressLogUpdate(operationId, 0.2, "Loading documents")
        chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(documentList)
        if not chatDocuments:
            self.services.chat.progressLogFinish(operationId, False)
            return ActionResult.isFailure(error="No documents found in documentList")

        logger.info(f"Extracting content from {len(chatDocuments)} documents")

        self.services.chat.progressLogUpdate(operationId, 0.3, "Preparing extraction options")
        extractionOptions = _resolveExtractionOptions(parameters.get("extractionOptions"))

        self.services.chat.progressLogUpdate(operationId, 0.4, "Initiating")
        self.services.chat.progressLogUpdate(operationId, 0.5, f"Extracting content from {len(chatDocuments)} documents")
        # operationId is forwarded so the extraction service can log per-document
        # progress under this operation.
        extractedResults = self.services.extraction.extractContent(chatDocuments, extractionOptions, operationId=operationId)

        self.services.chat.progressLogUpdate(operationId, 0.8, "Building result documents")
        # Results are paired with source documents by position; this assumes the
        # extraction service returns them in the same order as chatDocuments.
        actionDocuments = [
            _buildActionDocument(i, extracted, chatDocuments[i] if i < len(chatDocuments) else None)
            for i, extracted in enumerate(extractedResults)
        ]

        self.services.chat.progressLogFinish(operationId, True)
        return ActionResult.isSuccess(documents=actionDocuments)

    except Exception as e:
        # logger.exception keeps the traceback for post-mortem debugging.
        logger.exception(f"Error in content extraction: {str(e)}")
        if progressStarted:
            try:
                self.services.chat.progressLogFinish(operationId, False)
            except Exception:
                # Best-effort cleanup: a progress-log failure must not mask the
                # original error, but it should still leave a trace.
                logger.debug("Failed to close progress log for operation %s", operationId, exc_info=True)
        return ActionResult.isFailure(error=str(e))


def _defaultMergeStrategy():
    """Build the fallback merge strategy: concatenate, group by ``typeGroup``, order by ``id``."""
    return MergeStrategy(
        mergeType="concatenate",
        groupBy="typeGroup",
        orderBy="id",
    )


def _resolveExtractionOptions(extractionOptionsParam: Any):
    """Turn the raw ``extractionOptions`` action parameter into an ``ExtractionOptions``.

    Accepted inputs:
      * a non-empty dict -- missing ``prompt`` / ``mergeStrategy`` keys are filled
        with defaults on a copy, which is then passed to ``ExtractionOptions(**...)``;
      * an ``ExtractionOptions`` instance -- returned unchanged;
      * anything else (``None``, an empty dict, another type), or a dict that
        ``ExtractionOptions`` rejects -- replaced by the built-in defaults. A
        warning is logged for the wrong-type and rejected-dict cases.

    Args:
        extractionOptionsParam: The value of ``parameters.get("extractionOptions")``.

    Returns:
        An ``ExtractionOptions`` object; never ``None``.
    """
    if extractionOptionsParam:
        if isinstance(extractionOptionsParam, dict):
            # Fill defaults on a shallow copy so the caller's parameters dict is not mutated.
            optionsDict = dict(extractionOptionsParam)
            optionsDict.setdefault("prompt", "Extract all content from the document")
            optionsDict.setdefault("mergeStrategy", _defaultMergeStrategy())
            try:
                return ExtractionOptions(**optionsDict)
            except Exception as e:
                logger.warning(f"Failed to create ExtractionOptions from dict: {str(e)}, using defaults")
        elif isinstance(extractionOptionsParam, ExtractionOptions):
            return extractionOptionsParam
        else:
            logger.warning(f"Invalid extractionOptions type: {type(extractionOptionsParam)}, using defaults")

    # Default extraction options for pure content extraction (no AI processing)
    return ExtractionOptions(
        prompt="Extract all content from the document",
        mergeStrategy=_defaultMergeStrategy(),
        processDocumentsIndividually=True,
    )


def _buildActionDocument(index: int, extracted: Any, originalDoc: Any):
    """Wrap one extraction result in a JSON ``ActionDocument``.

    Args:
        index: Zero-based position of ``extracted`` in the extraction results.
        extracted: One extraction result; its ``id`` and ``parts`` attributes are read.
        originalDoc: The source ChatDocument at the same position, or ``None`` when
            there is no document at that index.

    Returns:
        An ``ActionDocument`` named ``<base>_extracted_<id>.json`` (or
        ``document_NNN_extracted_<id>.json`` when no file name is available), with
        the extraction result itself as ``documentData``.
    """
    originalFileName = getattr(originalDoc, "fileName", None) if originalDoc else None

    if originalFileName:
        # Drop only the last extension ("report.v2.pdf" -> "report.v2").
        baseName = originalFileName.rsplit(".", 1)[0]
        documentName = f"{baseName}_extracted_{extracted.id}.json"
    else:
        # No usable file name: fall back to a 1-based, zero-padded index.
        documentName = f"document_{index + 1:03d}_extracted_{extracted.id}.json"

    validationMetadata = {
        "actionType": "context.extractContent",
        "documentIndex": index,
        "extractedId": extracted.id,
        "partCount": len(extracted.parts) if extracted.parts else 0,
        "originalFileName": originalFileName,
    }
    return ActionDocument(
        documentName=documentName,
        documentData=extracted,  # ContentExtracted object
        mimeType="application/json",
        validationMetadata=validationMetadata,
    )
|
|
|