gateway/modules/connectors/connectorAiOpenai.py

270 lines
No EOL
10 KiB
Python

import logging
import base64
import httpx
from typing import Dict, Any, List, Union
from fastapi import HTTPException
from modules.shared.configuration import APP_CONFIG
# Module-level logger named after this module, shared by the connector code below.
logger: logging.Logger = logging.getLogger(name=__name__)
class ContextLengthExceededException(Exception):
    """Signals that a request exceeded the model's context length limit.

    Raised as a dedicated type (rather than a generic HTTP error) so callers can
    tell an oversized prompt apart from other API failures.
    """
def loadConfigData():
    """Load configuration data for the OpenAI connector from APP_CONFIG.

    Returns:
        Dict with keys ``apiKey``, ``apiUrl`` and ``modelName`` (values exactly as
        returned by ``APP_CONFIG.get``), plus ``temperature`` (float) and
        ``maxTokens`` (int). Numeric settings that are unset or blank fall back to
        0.2 and 2000, matching the fallbacks used by ``AiOpenai.callAiBasic``.

    Raises:
        ValueError / TypeError: If a numeric setting is present but cannot be
            converted to the required type.
    """
    def readNumber(key, cast, default):
        # Read one numeric setting, using `default` when it is unset or blank
        # (previously float(None) / int("") failed with an unhelpful error).
        rawValue = APP_CONFIG.get(key)
        if rawValue is None or (isinstance(rawValue, str) and not rawValue.strip()):
            return default
        return cast(rawValue)

    return {
        "apiKey": APP_CONFIG.get('Connector_AiOpenai_API_SECRET'),
        "apiUrl": APP_CONFIG.get('Connector_AiOpenai_API_URL'),
        "modelName": APP_CONFIG.get('Connector_AiOpenai_MODEL_NAME'),
        "temperature": readNumber('Connector_AiOpenai_TEMPERATURE', float, 0.2),
        "maxTokens": readNumber('Connector_AiOpenai_MAX_TOKENS', int, 2000),
    }
class AiOpenai:
    """Connector for communication with the OpenAI API.

    Provides chat completions (``callAiBasic``), image analysis with a vision
    model (``callAiImage``) and DALL-E 3 image generation (``generateImage``).
    Chat and vision requests share one ``httpx.AsyncClient`` created in
    ``__init__``; call ``close()`` to release it when the connector is no longer
    needed.
    """

    # Vision-capable model used by callAiImage instead of the configured chat model
    # ("gpt-4-vision-preview" is an alternative depending on availability).
    visionModelName = "gpt-4o"

    # DALL-E 3 endpoint used by generateImage; independent of the configured chat URL.
    imageGenerationUrl = "https://api.openai.com/v1/images/generations"

    def __init__(self):
        """Load the connector configuration and create the shared HTTP client."""
        self.config = loadConfigData()
        self.apiKey = self.config["apiKey"]
        self.apiUrl = self.config["apiUrl"]
        self.modelName = self.config["modelName"]
        # HttpClient for chat / vision API calls
        self.httpClient = httpx.AsyncClient(
            timeout=120.0,  # Longer timeout for complex requests
            headers=self._buildHeaders(),
        )
        logger.info(f"OpenAI Connector initialized with model: {self.modelName}")

    def _buildHeaders(self) -> Dict[str, str]:
        """Return the bearer-token and JSON content-type headers sent with every request."""
        return {
            "Authorization": f"Bearer {self.apiKey}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close the shared HTTP client used by callAiBasic and callAiImage."""
        await self.httpClient.aclose()

    @staticmethod
    def _contextLengthErrorMessage(errorData: Any) -> Union[str, None]:
        """Return the API error message if ``errorData`` reports a context-length overflow.

        Args:
            errorData: Parsed JSON body of a failed request. The expected shape is
                ``{"error": {"code": ..., "message": ...}}``, but any value is accepted.

        Returns:
            The error message ("Unknown error" when no non-empty string message is
            present) if the error code is ``context_length_exceeded`` or the message
            mentions "context length"; otherwise None.
        """
        error = errorData.get("error") if isinstance(errorData, dict) else None
        if not isinstance(error, dict):
            return None
        message = error.get("message")
        messageText = message if isinstance(message, str) else ""
        if error.get("code") == "context_length_exceeded" or "context length" in messageText.lower():
            return messageText or "Unknown error"
        return None

    @staticmethod
    def _toBase64Text(imageData: Union[str, bytes]) -> str:
        """Return ``imageData`` as base64 text suitable for a data URL.

        Raw ``bytes``/``bytearray`` are base64-encoded. A ``str`` is treated as
        already base64-encoded and gets missing ``=`` padding appended.

        Raises:
            ValueError: If ``imageData`` is neither text nor bytes.
        """
        if isinstance(imageData, (bytes, bytearray)):
            return base64.b64encode(imageData).decode("ascii")
        if not isinstance(imageData, str):
            raise ValueError("imageData must be a base64 string or raw bytes")
        # Fix base64 padding if needed
        paddingNeeded = len(imageData) % 4
        if paddingNeeded:
            imageData += "=" * (4 - paddingNeeded)
        return imageData

    async def callAiBasic(self, messages: List[Dict[str, Any]], temperature: float = None, maxTokens: int = None) -> str:
        """
        Calls the OpenAI chat completions API with the given messages.

        Args:
            messages: List of messages in OpenAI format (role, content)
            temperature: Temperature for response generation (0.0-1.0); the
                configured value is used when None
            maxTokens: Maximum number of tokens in the response; the configured
                value is used when None

        Returns:
            The content of the first choice's message.

        Raises:
            ContextLengthExceededException: If the API answers HTTP 400 because the
                input exceeds the model's context length.
            HTTPException: (status 500) for any other error in API communication or
                response parsing.
        """
        try:
            # Use parameters from configuration if none were overridden
            if temperature is None:
                temperature = self.config.get("temperature", 0.2)
            if maxTokens is None:
                maxTokens = self.config.get("maxTokens", 2000)
            payload = {
                "model": self.modelName,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": maxTokens,
            }
            response = await self.httpClient.post(self.apiUrl, json=payload)
            if response.status_code != 200:
                logger.error(f"OpenAI API error: {response.status_code} - {response.text}")
                # Check for context length exceeded error
                if response.status_code == 400:
                    try:
                        errorData = response.json()
                    except ValueError:
                        errorData = None  # Unparseable body: fall through to the generic error
                    contextMessage = self._contextLengthErrorMessage(errorData)
                    if contextMessage is not None:
                        raise ContextLengthExceededException(
                            f"Context length exceeded: {contextMessage}"
                        )
                raise HTTPException(status_code=500, detail="Error communicating with OpenAI API")
            responseJson = response.json()
            return responseJson["choices"][0]["message"]["content"]
        except (ContextLengthExceededException, HTTPException):
            # Already meaningful to callers: re-raise without logging/wrapping a second time
            raise
        except Exception as e:
            logger.error(f"Error calling OpenAI API: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}") from e

    async def callAiImage(self, prompt: str, imageData: Union[str, bytes], mimeType: str = None) -> str:
        """
        Analyzes an image with the OpenAI Vision API.

        Args:
            prompt: The prompt for analysis
            imageData: Base64-encoded image text, or raw image bytes (encoded here)
            mimeType: The MIME type of the image; "image/jpeg" when not given

        Returns:
            The response from the OpenAI Vision API as text. On any failure this
            method does not raise; it returns "[Error during image analysis: ...]".
        """
        try:
            logger.debug(f"Starting image analysis with query '{prompt}' for size {len(imageData)}B...")
            imageBase64 = self._toBase64Text(imageData)
            # Use default MIME type if not provided
            if not mimeType:
                mimeType = "image/jpeg"
            logger.debug(f"Using MIME type: {mimeType}")
            logger.debug(f"Base64 data length: {len(imageBase64)} characters")
            # Create the data URL format as required by OpenAI Vision API
            dataUrl = f"data:{mimeType};base64,{imageBase64}"
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": dataUrl}},
                    ],
                }
            ]
            payload = {
                # The configured chat model may not accept images, so use the vision model
                "model": self.visionModelName,
                "messages": messages,
                "temperature": self.config.get("temperature", 0.2),
                "max_tokens": self.config.get("maxTokens", 2000),
            }
            response = await self.httpClient.post(self.apiUrl, json=payload)
            if response.status_code != 200:
                logger.error(f"OpenAI API error: {response.status_code} - {response.text}")
                raise HTTPException(status_code=500, detail="Error communicating with OpenAI API")
            responseJson = response.json()
            return responseJson["choices"][0]["message"]["content"]
        except Exception as e:
            logger.error(f"Error during image analysis: {str(e)}", exc_info=True)
            return f"[Error during image analysis: {str(e)}]"

    async def generateImage(self, prompt: str, size: str = "1024x1024", quality: str = "standard", style: str = "vivid") -> Dict[str, Any]:
        """
        Generate an image using DALL-E 3.

        Args:
            prompt: The text prompt for image generation
            size: Image size (1024x1024, 1792x1024, or 1024x1792)
            quality: Image quality (standard or hd)
            style: Image style (vivid or natural)

        Returns:
            On success: {"success": True, "image_data": <base64 string>, "size",
            "quality", "style"}. On failure: {"success": False, "error": <message>};
            this method does not raise.
        """
        try:
            logger.debug(f"Starting image generation with prompt: '{prompt[:100]}...'")
            payload = {
                "model": "dall-e-3",
                "prompt": prompt,
                "size": size,
                "quality": quality,
                "style": style,
                "n": 1,
                "response_format": "b64_json",  # Get base64 data directly instead of URLs
            }
            # Short-lived client for the DALL-E call; `async with` closes it even if the request raises
            async with httpx.AsyncClient(timeout=120.0, headers=self._buildHeaders()) as dalleClient:
                response = await dalleClient.post(self.imageGenerationUrl, json=payload)
            if response.status_code != 200:
                logger.error(f"DALL-E API error: {response.status_code} - {response.text}")
                return {
                    "success": False,
                    "error": f"DALL-E API error: {response.status_code} - {response.text}",
                }
            responseJson = response.json()
            if "data" in responseJson and len(responseJson["data"]) > 0:
                imageData = responseJson["data"][0]["b64_json"]
                logger.info(f"Successfully generated image: {len(imageData)} characters")
                return {
                    "success": True,
                    "image_data": imageData,
                    "size": size,
                    "quality": quality,
                    "style": style,
                }
            logger.error("No image data in DALL-E response")
            return {
                "success": False,
                "error": "No image data in DALL-E response",
            }
        except Exception as e:
            logger.error(f"Error during image generation: {str(e)}", exc_info=True)
            return {
                "success": False,
                "error": f"Error during image generation: {str(e)}",
            }