132 lines
4.5 KiB
Python
132 lines
4.5 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Image Generation Path
|
|
|
|
Handles image generation with support for single and batch generation.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from typing import List, Optional
|
|
from modules.datamodels.datamodelWorkflow import AiResponse, AiResponseMetadata, DocumentData
|
|
from modules.datamodels.datamodelAi import AiCallOptions, OperationTypeEnum, AiCallRequest
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ImageGenerationPath:
    """Image generation path.

    Sends an ``IMAGE_GENERATE`` request to the AI service and wraps the
    returned image bytes as a document inside an ``AiResponse``.
    """

    # Formats whose MIME subtype differs from the usual file extension.
    _MIME_SUBTYPE_ALIASES = {
        "jpg": "jpeg",
        "svg": "svg+xml",
        "tif": "tiff",
    }

    def __init__(self, services):
        """Store the service container.

        Args:
            services: Object exposing the ``workflow``, ``chat`` and ``ai``
                attributes that ``generateImages`` uses.
        """
        self.services = services

    @staticmethod
    def _decodeImageData(rawContent) -> bytes:
        """Convert the AI response content into raw bytes.

        Args:
            rawContent: ``bytes`` (returned unchanged), a ``str`` holding
                base64 data (optionally as a ``data:...;base64,`` URI), or any
                other object accepted by ``bytes()``.

        Returns:
            The decoded bytes. A string that is not valid base64 is returned
            UTF-8 encoded, as a best-effort fallback.
        """
        if isinstance(rawContent, bytes):
            return rawContent
        if not isinstance(rawContent, str):
            return bytes(rawContent)

        # Local import keeps this helper self-contained.
        import base64

        payload = rawContent
        # Drop a data-URI header so only the base64 payload is decoded.
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        # Line-wrapped base64 is legitimate; remove whitespace before the
        # strict validation below.
        payload = "".join(payload.split())

        try:
            # validate=True rejects non-base64 characters instead of silently
            # discarding them, which would decode arbitrary text to garbage.
            return base64.b64decode(payload, validate=True)
        except ValueError:
            # binascii.Error is a ValueError subclass; non-ASCII input also
            # raises ValueError. Fall back to the raw text as bytes.
            return rawContent.encode("utf-8")

    @classmethod
    def _resolveMimeType(cls, imageFormat) -> str:
        """Return the ``image/...`` MIME type for a format name.

        Known aliases (e.g. ``jpg`` -> ``image/jpeg``) are mapped; any other
        value is used as the subtype unchanged.
        """
        subtype = cls._MIME_SUBTYPE_ALIASES.get(str(imageFormat).lower(), imageFormat)
        return f"image/{subtype}"

    async def generateImages(
        self,
        userPrompt: str,
        count: int = 1,
        style: Optional[str] = None,
        format: str = "png",
        title: Optional[str] = None,
        parentOperationId: Optional[str] = None
    ) -> "AiResponse":
        """Generate an image file from a text prompt.

        Args:
            userPrompt: Text describing the image to generate.
            count: Requested number of images. Currently not used: a single
                image is generated per call and a warning is logged for any
                other value.
            style: Optional style hint appended to the prompt.
            format: Result format passed to the AI call; also used as the
                file extension and to derive the MIME type.
            title: Optional title for the response metadata; defaults to
                "Generated Image".
            parentOperationId: Optional parent operation for progress logging.

        Returns:
            AiResponse whose ``documents`` holds the generated image and whose
            ``content`` is a JSON string describing the generation.

        Raises:
            ValueError: If the AI response carries no content.
            Exception: Any other error is logged, the progress operation is
                marked as failed, and the error is re-raised.
        """
        import json

        if count != 1:
            # Make the ignored argument visible instead of silently dropping it.
            logger.warning(
                "generateImages called with count=%s; only a single image is generated",
                count,
            )

        # Create operation ID
        workflow = self.services.workflow
        workflowId = workflow.id if workflow else f"no-workflow-{int(time.time())}"
        imageOperationId = f"image_gen_{workflowId}_{int(time.time())}"

        # Start progress tracking
        self.services.chat.progressLogStart(
            imageOperationId,
            "Image Generation",
            "Image Generation",
            f"Format: {format}",
            parentOperationId=parentOperationId
        )

        try:
            self.services.chat.progressLogUpdate(imageOperationId, 0.4, "Calling AI for image generation")

            # Build prompt with style if provided
            imagePrompt = f"{userPrompt}\n\nStyle: {style}" if style else userPrompt

            # Use IMAGE_GENERATE operation
            options = AiCallOptions(
                operationType=OperationTypeEnum.IMAGE_GENERATE,
                resultFormat=format
            )
            request = AiCallRequest(
                prompt=imagePrompt,
                context="",
                options=options
            )

            response = await self.services.ai.callAi(request)

            if not response.content:
                # Logging and progress finalisation happen once, in the
                # handler below, so the operation is not finished twice.
                raise ValueError(f"No image data returned: {response.content}")

            # Response content may be a base64 string or raw bytes
            imageDoc = DocumentData(
                documentName=f"generated_image.{format}",
                documentData=self._decodeImageData(response.content),
                mimeType=self._resolveMimeType(format)
            )

            metadata = AiResponseMetadata(
                title=title or "Generated Image",
                operationType=OperationTypeEnum.IMAGE_GENERATE.value
            )

            self.services.chat.storeWorkflowStat(
                self.services.workflow,
                response,
                "ai.generate.image"
            )

            # JSON string describing the image generation
            contentJson = json.dumps({
                "type": "image",
                "format": format,
                "prompt": userPrompt,
                "filename": imageDoc.documentName
            }, ensure_ascii=False)

            result = AiResponse(
                content=contentJson,
                metadata=metadata,
                documents=[imageDoc]
            )

            # Report success only after the full response has been built, so
            # a late failure cannot follow a "finished successfully" entry.
            self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated")
            self.services.chat.progressLogFinish(imageOperationId, True)
            return result

        except Exception as e:
            # logger.exception records the traceback along with the message.
            logger.exception("Error in image generation: %s", e)
            self.services.chat.progressLogFinish(imageOperationId, False)
            raise
|
|
|