gateway/modules/services/serviceAi/subSharedAiUtils.py
2025-10-24 23:57:17 +02:00

163 lines
5.6 KiB
Python

"""
Shared utilities for AI services to eliminate code duplication.
This module contains common functions used across multiple AI service modules
to maintain DRY principles and ensure consistency.
"""
import re
import logging
from typing import Dict, Any, List, Optional, Union
# Module-level logger; sanitizePromptContent reports sanitization failures through it.
logger = logging.getLogger(__name__)
def buildPromptWithPlaceholders(prompt: str, placeholders: Optional[Dict[str, str]]) -> str:
    """
    Build the full prompt by substituting placeholder tokens with their content.

    For every key in ``placeholders`` both the legacy ``{{name}}`` form and the
    ``{{KEY:name}}`` form are recognised. Substitution is a single pass over the
    original template, so placeholder tokens that happen to appear inside
    substituted content are kept verbatim instead of being expanded by a later
    key (sequential ``str.replace`` calls would re-expand them).

    Args:
        prompt: The base prompt template.
        placeholders: Mapping of placeholder name to replacement content.
            ``None`` or an empty mapping returns ``prompt`` unchanged.
            ``None`` values are substituted as an empty string; other
            non-string values are converted with ``str()``.

    Returns:
        The prompt with every recognised placeholder token replaced. Tokens
        whose name is not in ``placeholders`` are left untouched.
    """
    if not placeholders:
        return prompt

    # Normalise keys to strings so that token text (always a string) can be
    # looked up directly from the regex match.
    lookup = {str(name): value for name, value in placeholders.items()}
    pattern = re.compile(
        r"\{\{(?:KEY:)?(" + "|".join(re.escape(name) for name in lookup) + r")\}\}"
    )

    def _substitute(match):
        # Resolve one matched token to its replacement text.
        value = lookup[match.group(1)]
        return "" if value is None else str(value)

    return pattern.sub(_substitute, prompt)
# C0 control characters and DEL, excluding tab (\x09), newline (\x0A) and
# carriage return (\x0D); these are stripped from every content type.
_PROMPT_CONTROL_CHARS = re.compile(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]')

# Escape tables applied with str.translate. Translation is a single pass, which
# is equivalent to chained str.replace calls that escape the backslash first:
# an escape sequence produced for one character is never re-escaped by another.
_JSON_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '"': '\\"',
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
})
_TEXT_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '"': '\\"',
    "'": "\\'",
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
})
# The backslash must be escaped too: otherwise an input ending in "\" followed
# by "'" would turn the added "\'" into "\\'" and close the surrounding quotes.
_USERINPUT_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '{': '{{',
    '}': '}}',
    '"': '\\"',
    "'": "\\'",
})


def sanitizePromptContent(content: str, contentType: str = "text") -> str:
    """
    Sanitize content before it is inserted into an AI prompt.

    Control characters (other than tab, newline and carriage return) are
    removed for every content type, then type-specific escaping is applied:

    - ``"userinput"``: backslashes and quotes are backslash-escaped, curly
      braces are doubled, and the result is wrapped in single quotes.
      Newlines are left as-is.
    - ``"json"``: backslash, double quote, newline, carriage return and tab
      are escaped. Single quotes are left as-is.
    - ``"text"``, ``"document"`` and any unrecognised type: like ``"json"``,
      plus single quotes are escaped.

    Args:
        content: The content to sanitize. Non-string values are converted with
            ``str()``; falsy values (``None``, ``""``, ``0`` ...) yield ``""``.
        contentType: One of ``"text"``, ``"userinput"``, ``"json"``,
            ``"document"``.

    Returns:
        The sanitized content, or a fixed error marker string if conversion
        or sanitization raised (the error is logged).
    """
    if not content:
        return ""
    try:
        sanitized = _PROMPT_CONTROL_CHARS.sub('', str(content))
        if contentType == "userinput":
            # NOTE(review): doubling braces keeps "{{KEY:x}}" as a substring
            # ("{{{{KEY:x}}}}"), so this alone does not stop placeholder
            # expansion by plain substring replacement — confirm intent.
            return f"'{sanitized.translate(_USERINPUT_ESCAPES)}'"
        if contentType == "json":
            return sanitized.translate(_JSON_ESCAPES)
        # "text", "document" and unknown types share the same escaping.
        return sanitized.translate(_TEXT_ESCAPES)
    except Exception as e:
        logger.error(f"Error sanitizing prompt content: {str(e)}")
        # Never pass unsanitized content through on failure.
        return "[ERROR: Content could not be safely sanitized]"
def extractTextFromContentParts(extracted_content) -> str:
    """
    Extract text content from ExtractionService ContentPart objects.

    Args:
        extracted_content: Object exposing a ``parts`` iterable (presumably a
            ContentExtracted instance — confirm against ExtractionService).
            Each part is read through its ``typeGroup`` and ``data`` attributes.

    Returns:
        The ``data`` of every part whose ``typeGroup`` is ``"text"``,
        ``"table"`` or ``"structure"`` and whose ``data`` is truthy, joined with
        a blank line. Returns ``""`` if ``extracted_content`` is falsy, has no
        ``parts`` attribute, or its ``parts`` is ``None``/empty.
    """
    if not extracted_content:
        return ""
    # A missing attribute and a None value are treated the same way.
    parts = getattr(extracted_content, 'parts', None)
    if not parts:
        return ""
    text_type_groups = ('text', 'table', 'structure')
    text_parts = [
        part.data
        for part in parts
        if getattr(part, 'typeGroup', None) in text_type_groups
        and getattr(part, 'data', None)
    ]
    return "\n\n".join(text_parts)
def reduceText(text: str, reduction_factor: float) -> str:
    """
    Shorten text to a fraction of its original length.

    Args:
        text: Text to reduce.
        reduction_factor: Fraction of the text to KEEP, expected in 0.0-1.0.
            Values >= 1.0 return ``text`` unchanged; values below 0.0 are
            treated as 0.0.

    Returns:
        The first ``int(len(text) * reduction_factor)`` characters followed by
        the ``"... [reduced]"`` marker, or ``text`` unchanged when it is empty
        or ``reduction_factor >= 1.0``.
    """
    # Empty text cannot be shortened; appending the marker would lengthen it.
    if not text or reduction_factor >= 1.0:
        return text
    # Clamp negative factors: a negative slice end would cut from the tail
    # instead of keeping a prefix.
    target_length = int(len(text) * max(reduction_factor, 0.0))
    return text[:target_length] + "... [reduced]"
def determineCallType(documents: Optional[List], operation_type: str) -> str:
    """
    Pick the AI call type for an operation.

    A call is a "plan" call only when the operation itself is "plan" and no
    documents are attached; every other combination is a "text" call.

    Args:
        documents: Attached ChatDocument objects, or None.
        operation_type: Name of the operation being performed.

    Returns:
        "plan" or "text".
    """
    # Any attached document forces a text call, regardless of operation.
    if documents is not None and len(documents) > 0:
        return "text"
    return "plan" if operation_type == "plan" else "text"