gateway/modules/aichat/serviceExtraction/subRegistry.py
2026-01-22 21:11:25 +01:00

208 lines
8 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
from typing import Any, Dict, Optional
import logging
from modules.datamodels.datamodelExtraction import ContentPart
# Module-level logger used by the registries in this module.
logger = logging.getLogger(__name__)
class Extractor:
    """
    Common interface implemented by every document extractor.

    Subclasses override:
      - detect(): decide whether they can handle a given file
      - extract(): turn raw file bytes into ContentPart objects
      - getSupportedExtensions(): the file extensions they accept
      - getSupportedMimeTypes(): the MIME types they accept

    The defaults defined here decline every file, declare no formats and
    leave extract() unimplemented.
    """

    def detect(self, fileName: str, mimeType: str, headBytes: bytes) -> bool:
        """Report whether this extractor can process the file (base class: never)."""
        return False

    def extract(self, fileBytes: bytes, context: Dict[str, Any]) -> list[ContentPart]:
        """Extract content parts from *fileBytes*; concrete extractors must override this."""
        raise NotImplementedError

    def getSupportedExtensions(self) -> list[str]:
        """List the file extensions this extractor accepts, each with its leading dot."""
        accepted: list[str] = []
        return accepted

    def getSupportedMimeTypes(self) -> list[str]:
        """List the MIME types this extractor accepts."""
        accepted: list[str] = []
        return accepted
class Chunker:
    """
    Base chunker: splits a ContentPart into a list of chunk dictionaries.

    This default implementation produces no chunks at all.
    """

    def chunk(self, part: ContentPart, options: Dict[str, Any]) -> list[Dict[str, Any]]:
        """Return the chunks for *part*; the base implementation always returns an empty list."""
        pieces: list[Dict[str, Any]] = []
        return pieces
class ExtractorRegistry:
    """
    Registry that maps MIME types and file extensions to Extractor instances.

    On construction it imports every ``extractors/extractor*.py`` module next to
    this file, instantiates each public Extractor subclass found there and
    registers it under the MIME types and extensions it declares.
    BinaryExtractor is installed as the fallback that resolve() returns when
    nothing else matches.

    Registry keys are normalized (surrounding whitespace stripped, lower-cased)
    both in register() and in resolve(), so lookups are case-insensitive.
    """

    def __init__(self):
        self._map: Dict[str, Extractor] = {}
        self._fallback: Optional[Extractor] = None
        self._auto_discover_extractors()

    @staticmethod
    def _normalize_key(key: str) -> str:
        """Return the canonical form of a registry key (trimmed, lower-cased)."""
        return key.strip().lower()

    def _auto_discover_extractors(self):
        """
        Import all extractor modules, register the Extractor subclasses they
        expose, and install BinaryExtractor as the fallback.

        Failures are logged, never raised. A module that cannot be imported, or
        a class that cannot be instantiated, is skipped without affecting the
        others.
        """
        import importlib
        from pathlib import Path

        try:
            extractors_dir = Path(__file__).parent / "extractors"
            if not extractors_dir.exists():
                logger.error(f"Extractors directory not found: {extractors_dir}")
                return

            registered_names: list[str] = []
            seen_classes: set = set()
            # Sort the paths: glob() order is filesystem-dependent, and when two
            # extractors declare the same key the one registered last wins.
            for file_path in sorted(extractors_dir.glob("extractor*.py")):
                module_name = file_path.stem
                try:
                    # Import relative to this module's own package (as the
                    # fallback import below does) rather than via a hard-coded
                    # dotted path, so the extractor modules and this registry
                    # share a single copy of the Extractor base class.
                    module = importlib.import_module(f".extractors.{module_name}", package=__package__)
                except Exception as e:
                    logger.warning(f"Failed to import {module_name}: {str(e)}")
                    continue

                for attr_name in dir(module):
                    try:
                        attr = getattr(module, attr_name)
                        if (
                            not isinstance(attr, type)
                            or not issubclass(attr, Extractor)
                            or attr is Extractor
                            or attr_name.startswith('_')
                            # A class imported into several modules is registered once.
                            or attr in seen_classes
                        ):
                            continue
                        seen_classes.add(attr)
                        extractor_instance = attr()
                    except Exception as e:
                        logger.warning(f"Failed to load {attr_name} from {module_name}: {str(e)}")
                        continue
                    self._auto_register_extractor(extractor_instance)
                    registered_names.append(attr_name)

            try:
                from .extractors.extractorBinary import BinaryExtractor
                self.setFallback(BinaryExtractor())
            except Exception as e:
                logger.warning(f"Failed to set fallback extractor: {str(e)}")

            logger.info(f"ExtractorRegistry: Auto-discovered and registered {len(registered_names)} extractor classes: {', '.join(registered_names)}")
            logger.info(f"ExtractorRegistry: Total registered formats: {len(self._map)}")
        except Exception as e:
            logger.exception(f"ExtractorRegistry: Failed to auto-discover extractors: {str(e)}")

    def _auto_register_extractor(self, extractor: Extractor):
        """
        Register *extractor* under every MIME type and file extension it declares.

        Extensions are registered without their leading dot(s), and empty keys
        are skipped. Errors are logged, not raised.
        """
        try:
            for mime_type in extractor.getSupportedMimeTypes() or []:
                if mime_type and mime_type.strip():
                    self.register(mime_type, extractor)
            for ext in extractor.getSupportedExtensions() or []:
                ext_key = ext.lstrip('.')
                if ext_key.strip():
                    self.register(ext_key, extractor)
        except Exception as e:
            logger.error(f"Failed to auto-register {extractor.__class__.__name__}: {str(e)}")

    def register(self, key: str, extractor: Extractor):
        """
        Map *key* to *extractor*, replacing any previous mapping.

        *key* is a MIME type or a file extension without its dot. It is stored
        in normalized form (trimmed, lower-cased).
        """
        self._map[self._normalize_key(key)] = extractor

    def setFallback(self, extractor: Extractor):
        """Set the extractor that resolve() returns when no key matches."""
        self._fallback = extractor

    def resolve(self, mimeType: str, fileName: str) -> Optional[Extractor]:
        """
        Return the extractor for a file, or the fallback if none matches.

        The MIME type is tried first, case-insensitively and ignoring any
        parameters (``text/plain; charset=utf-8`` matches ``text/plain``).
        The lower-cased extension of *fileName* is tried next. Either argument
        may be empty or None.
        """
        if mimeType:
            extractor = self._map.get(self._normalize_key(mimeType.split(";", 1)[0]))
            if extractor is not None:
                return extractor
        # simple extension fallback
        if fileName and "." in fileName:
            extractor = self._map.get(self._normalize_key(fileName.rsplit(".", 1)[-1]))
            if extractor is not None:
                return extractor
        return self._fallback

    def getAllSupportedFormats(self) -> Dict[str, Dict[str, list[str]]]:
        """
        Get all supported formats from all registered extractors.

        Returns:
            Dictionary keyed by extractor class name, plus "fallback" for the
            fallback extractor:
            {
                "extensions": {
                    "extractor_name": [".ext1", ".ext2", ...]
                },
                "mime_types": {
                    "extractor_name": ["mime/type1", "mime/type2", ...]
                }
            }
            Each extractor appears once, even though it is registered under
            several keys. An extractor that declares no extensions (or no MIME
            types) is omitted from that section.
        """
        formats: Dict[str, Dict[str, list[str]]] = {"extensions": {}, "mime_types": {}}
        # The same instance is stored under every key it declared; dedupe by
        # identity while keeping first-registration order.
        unique_extractors = {id(extractor): extractor for extractor in self._map.values()}.values()
        for extractor in unique_extractors:
            name = extractor.__class__.__name__
            if hasattr(extractor, 'getSupportedExtensions'):
                extensions = extractor.getSupportedExtensions()
                if extensions:
                    formats["extensions"][name] = extensions
            if hasattr(extractor, 'getSupportedMimeTypes'):
                mime_types = extractor.getSupportedMimeTypes()
                if mime_types:
                    formats["mime_types"][name] = mime_types
        # Add fallback extractor info
        if self._fallback and hasattr(self._fallback, 'getSupportedExtensions'):
            formats["extensions"]["fallback"] = self._fallback.getSupportedExtensions()
        if self._fallback and hasattr(self._fallback, 'getSupportedMimeTypes'):
            formats["mime_types"]["fallback"] = self._fallback.getSupportedMimeTypes()
        return formats
class ChunkerRegistry:
    """
    Registry that maps a content type group (e.g. "text", "table") to a Chunker.

    The built-in chunkers from the sibling ``chunking`` package are registered
    on construction. A type group with no registered chunker resolves to a
    no-op Chunker that produces no chunks.
    """

    # (type group, module inside the .chunking package, class name).
    # Container and binary content are chunked as text; each type group gets
    # its own chunker instance.
    _DEFAULT_CHUNKERS = (
        ("text", "chunkerText", "TextChunker"),
        ("table", "chunkerTable", "TableChunker"),
        ("structure", "chunkerStructure", "StructureChunker"),
        ("image", "chunkerImage", "ImageChunker"),
        ("container", "chunkerText", "TextChunker"),
        ("binary", "chunkerText", "TextChunker"),
    )

    def __init__(self):
        self._map: Dict[str, Chunker] = {}
        self._noop = Chunker()
        self._register_defaults()

    def _register_defaults(self):
        """
        Import and register each built-in chunker independently.

        A failure (missing module, failing optional dependency, constructor
        error) is logged with its traceback and leaves only that one type group
        on the no-op chunker. The other groups are unaffected.
        """
        import importlib

        for type_group, module_name, class_name in self._DEFAULT_CHUNKERS:
            try:
                module = importlib.import_module(f".chunking.{module_name}", package=__package__)
                self.register(type_group, getattr(module, class_name)())
            except Exception as e:
                logger.exception(
                    f"ChunkerRegistry: Failed to register {class_name} for '{type_group}': {str(e)}"
                )

    def register(self, typeGroup: str, chunker: Chunker):
        """Map *typeGroup* to *chunker*, replacing any previous mapping."""
        self._map[typeGroup] = chunker

    def resolve(self, typeGroup: str) -> Chunker:
        """Return the chunker for *typeGroup*, or the no-op chunker if none is registered."""
        return self._map.get(typeGroup, self._noop)