163 lines
5.6 KiB
Python
163 lines
5.6 KiB
Python
"""
|
|
Shared utilities for AI services to eliminate code duplication.
|
|
|
|
This module contains common functions used across multiple AI service modules
|
|
to maintain DRY principles and ensure consistency.
|
|
"""
|
|
|
|
import re
|
|
import logging
|
|
from typing import Dict, Any, List, Optional, Union
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def buildPromptWithPlaceholders(prompt: str, placeholders: Optional[Dict[str, str]]) -> str:
    """
    Build the full prompt by substituting placeholders with their content.

    Both the legacy ``{{placeholder}}`` form and the newer
    ``{{KEY:placeholder}}`` form are recognised.

    All substitutions are performed in a single pass over the template, so
    text inserted for one placeholder is never re-scanned for further
    placeholders. The result therefore does not depend on dictionary order,
    and inserted content cannot trigger the expansion of another placeholder.

    Args:
        prompt: The base prompt template
        placeholders: Dictionary of placeholder key-value pairs. Values that
            are not strings are converted with ``str()``.

    Returns:
        Prompt with placeholders replaced. The prompt is returned unchanged
        when ``placeholders`` is ``None`` or empty.
    """
    if not placeholders:
        return prompt

    # Normalise to strings so a stray non-string value cannot raise mid-way.
    replacements = {str(key): str(value) for key, value in placeholders.items()}

    # One alternation over all (escaped) keys; the optional "KEY:" prefix
    # covers the new format, its absence covers the old one.
    alternatives = "|".join(re.escape(key) for key in replacements)
    pattern = re.compile(r"\{\{(?:KEY:)?(" + alternatives + r")\}\}")

    # A callable replacement keeps backslashes in the content literal
    # (a plain replacement string would be parsed for escapes like "\1").
    return pattern.sub(lambda match: replacements[match.group(1)], prompt)
|
|
|
|
|
|
def sanitizePromptContent(content: str, contentType: str = "text") -> str:
    """
    Centralized prompt content sanitization to prevent injection attacks.

    Control characters (everything in ``\\x00-\\x1F`` except tab, newline and
    carriage return, plus ``\\x7F``) are removed first. The remaining text is
    then escaped according to ``contentType``:

    * ``"userinput"``: backslashes are escaped, curly braces are doubled,
      double and single quotes are backslash-escaped, and the result is
      wrapped in single quotes. Newlines and tabs are left as-is.
    * ``"json"``: backslashes and double quotes are escaped; newline,
      carriage return and tab become ``\\n``, ``\\r``, ``\\t``.
    * ``"document"``, ``"text"`` and any other value: same as ``"json"``,
      plus single quotes are escaped.

    Args:
        content: The content to sanitize. Non-string values are converted
            with ``str()``; any falsy value yields an empty string.
        contentType: Type of content ("text", "userinput", "json", "document")

    Returns:
        Sanitized content ready for prompt insertion, or a fixed error
        marker string if sanitization raised an exception.
    """
    if not content:
        return ""

    try:
        # Strip null bytes and control characters; \t, \n and \r survive.
        sanitized = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', str(content))

        if contentType == "userinput":
            # Backslashes must be escaped BEFORE quotes: otherwise a trailing
            # "\" would escape the closing quote, and an input "\'" would
            # become "\\'", which terminates the quoted span early.
            sanitized = sanitized.replace('\\', '\\\\')
            # Double the braces so user text cannot form a {{placeholder}}.
            sanitized = sanitized.replace('{', '{{').replace('}', '}}')
            sanitized = sanitized.replace('"', '\\"').replace("'", "\\'")
            return f"'{sanitized}'"

        # Shared escaping for "json", "document", "text" and unknown types.
        sanitized = sanitized.replace('\\', '\\\\').replace('"', '\\"')
        if contentType != "json":
            # JSON strings are double-quoted, so single quotes stay untouched.
            sanitized = sanitized.replace("'", "\\'")
        return (
            sanitized
            .replace('\n', '\\n')
            .replace('\r', '\\r')
            .replace('\t', '\\t')
        )

    except Exception as e:
        # Best-effort boundary: never let sanitization crash prompt building.
        logger.error("Error sanitizing prompt content: %s", e)
        # Return a safe fallback
        return "[ERROR: Content could not be safely sanitized]"
|
|
|
|
|
|
def extractTextFromContentParts(extracted_content) -> str:
    """
    Collect the textual payload of an extracted-content object.

    Args:
        extracted_content: Object exposing a ``parts`` iterable; each part
            may carry ``typeGroup`` and ``data`` attributes.

    Returns:
        The ``data`` of every part whose ``typeGroup`` is "text", "table" or
        "structure" and whose ``data`` is truthy, joined by blank lines.
        An empty string when the input is falsy or has no ``parts``.
    """
    # Guard: nothing to read from a missing object or one without parts.
    if not (extracted_content and hasattr(extracted_content, 'parts')):
        return ""

    textual_groups = ('text', 'table', 'structure')
    chunks = [
        part.data
        for part in extracted_content.parts
        if getattr(part, 'typeGroup', None) in textual_groups
        and getattr(part, 'data', None)
    ]
    return "\n\n".join(chunks)
|
|
|
|
|
|
def reduceText(text: str, reduction_factor: float) -> str:
    """
    Reduce text size by the specified factor.

    The text is cut to ``int(len(text) * reduction_factor)`` leading
    characters and a truncation marker is appended.

    Args:
        text: Text to reduce
        reduction_factor: Fraction of the text to keep (0.0 to 1.0). Values
            of 1.0 or more return the text unchanged; values below 0.0 are
            treated as 0.0.

    Returns:
        The original text when no reduction is requested, otherwise the
        kept prefix followed by "... [reduced]".
    """
    if reduction_factor >= 1.0:
        return text

    # Clamp at zero: a negative length would be read by the slice as an
    # offset from the END of the string and keep most of the text.
    target_length = max(0, int(len(text) * reduction_factor))
    return text[:target_length] + "... [reduced]"
|
|
|
|
|
|
def determineCallType(documents: Optional[List], operation_type: str) -> str:
    """
    Pick the call type for a request.

    Args:
        documents: Documents attached to the request; may be None or empty
        operation_type: Type of operation being performed

    Returns:
        "plan" only when ``operation_type`` is "plan" and there are no
        documents; "text" in every other case.
    """
    lacks_documents = documents is None or len(documents) == 0
    return "plan" if lacks_documents and operation_type == "plan" else "text"
|