# gateway/modules/aicore/aicorePluginMistral.py
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import logging
import json as _json
import httpx
from typing import List, Dict, Any, AsyncGenerator, Union
from fastapi import HTTPException
from modules.shared.configuration import APP_CONFIG
from .aicoreBase import BaseConnectorAi, RateLimitExceededException, ContextLengthExceededException
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings

logger = logging.getLogger(__name__)


def loadConfigData():
    """Load configuration data for the Mistral connector."""
    return {
        "apiKey": APP_CONFIG.get('Connector_AiMistral_API_SECRET'),
    }
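# A minimal configuration sketch (the key name matches the lookup above; how
# APP_CONFIG is populated, e.g. from an .env file, is an assumption):
#
#     Connector_AiMistral_API_SECRET=<your Mistral API key>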
class AiMistral(BaseConnectorAi):
    """Connector for communication with the Mistral AI API (Le Chat Mistral)."""

    def __init__(self):
        super().__init__()
        # Load configuration
        self.config = loadConfigData()
        self.apiKey = self.config["apiKey"]
        # HTTP client for API calls. The timeout is set to 600 seconds
        # (10 minutes) because AiService calls can take significantly longer
        # than plain completions due to prompt building and processing overhead.
        self.httpClient = httpx.AsyncClient(
            timeout=600.0,
            headers={
                "Authorization": f"Bearer {self.apiKey}",
                "Content-Type": "application/json"
            }
        )
        logger.info("Mistral Connector initialized")
    def getConnectorType(self) -> str:
        """Get the connector type identifier."""
        return "mistral"
    def getModels(self) -> List[AiModel]:
        """Get all available Mistral models."""
        # The calculatepriceCHF lambdas approximate token counts from payload
        # sizes (~4 bytes per token) and apply the model's per-1k-token rates.
        return [
            AiModel(
                name="mistral-large-latest",
                displayName="Mistral Large 3",
                connectorType="mistral",
                apiUrl="https://api.mistral.ai/v1/chat/completions",
                temperature=0.2,
                maxTokens=16384,
                contextLength=256000,
                costPer1kTokensInput=0.0005,  # $0.50/M tokens (updated 2026-02)
                costPer1kTokensOutput=0.0015,  # $1.50/M tokens (updated 2026-02)
                speedRating=8,  # Good speed for complex tasks
                qualityRating=9,  # High quality
                functionCall=self.callAiBasic,
                functionCallStream=self.callAiBasicStream,
                priority=PriorityEnum.BALANCED,
                processingMode=ProcessingModeEnum.ADVANCED,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.PLAN, 9),
                    (OperationTypeEnum.DATA_ANALYSE, 9),
                    (OperationTypeEnum.DATA_GENERATE, 9),
                    (OperationTypeEnum.DATA_EXTRACT, 8),
                    (OperationTypeEnum.AGENT, 8),
                ),
                version="mistral-large-latest",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0005 + (bytesReceived / 4 / 1000) * 0.0015
            ),
            AiModel(
                name="mistral-small-latest",
                displayName="Mistral Small 3.2",
                connectorType="mistral",
                apiUrl="https://api.mistral.ai/v1/chat/completions",
                temperature=0.2,
                maxTokens=16384,
                contextLength=128000,
                costPer1kTokensInput=0.00006,  # $0.06/M tokens (updated 2026-02)
                costPer1kTokensOutput=0.00018,  # $0.18/M tokens (updated 2026-02)
                speedRating=9,  # Very fast, lightweight model
                qualityRating=7,  # Good quality, cost-efficient
                functionCall=self.callAiBasic,
                functionCallStream=self.callAiBasicStream,
                priority=PriorityEnum.SPEED,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.PLAN, 7),
                    (OperationTypeEnum.DATA_ANALYSE, 7),
                    (OperationTypeEnum.DATA_GENERATE, 8),
                    (OperationTypeEnum.DATA_EXTRACT, 7),
                    (OperationTypeEnum.AGENT, 6),
                ),
                version="mistral-small-latest",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00006 + (bytesReceived / 4 / 1000) * 0.00018
            ),
            AiModel(
                name="mistral-embed",
                displayName="Mistral Embed",
                connectorType="mistral",
                apiUrl="https://api.mistral.ai/v1/embeddings",
                temperature=0.0,
                maxTokens=0,
                contextLength=8192,
                costPer1kTokensInput=0.0001,  # $0.10/M tokens
                costPer1kTokensOutput=0.0,
                speedRating=10,
                qualityRating=7,
                functionCall=self.callEmbedding,
                priority=PriorityEnum.COST,
                processingMode=ProcessingModeEnum.BASIC,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.EMBEDDING, 8)
                ),
                version="mistral-embed",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0001
            ),
            AiModel(
                name="mistral-large-latest",
                displayName="Mistral Large 3 Vision",
                connectorType="mistral",
                apiUrl="https://api.mistral.ai/v1/chat/completions",
                temperature=0.2,
                maxTokens=16384,
                contextLength=256000,
                costPer1kTokensInput=0.0005,  # $0.50/M tokens (updated 2026-02)
                costPer1kTokensOutput=0.0015,  # $1.50/M tokens (updated 2026-02)
                speedRating=6,  # Slower for vision tasks
                qualityRating=8,  # Good quality vision
                functionCall=self.callAiImage,
                priority=PriorityEnum.QUALITY,
                processingMode=ProcessingModeEnum.DETAILED,
                operationTypes=createOperationTypeRatings(
                    (OperationTypeEnum.IMAGE_ANALYSE, 8)
                ),
                version="mistral-large-latest",
                calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0005 + (bytesReceived / 4 / 1000) * 0.0015
            )
        ]
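    # A minimal usage sketch (hedged: the AiModelCall constructor arguments and
    # how the surrounding registry selects models are assumptions, not shown in
    # this module):
    #
    #     connector = AiMistral()
    #     model = next(m for m in connector.getModels() if m.name == "mistral-small-latest")
    #     call = AiModelCall(model=model, messages=[{"role": "user", "content": "Hi"}])
    #     response = await model.functionCall(call)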
    async def callAiBasic(self, modelCall: AiModelCall) -> AiModelResponse:
        """
        Calls the Mistral AI API with the given messages using the standardized pattern.

        Mistral's chat completions API is OpenAI-compatible: it accepts the same
        message format (role/content) including system messages, and returns
        responses in the same choices[0].message.content structure.

        Args:
            modelCall: AiModelCall with messages and options

        Returns:
            AiModelResponse with content and metadata

        Raises:
            HTTPException: For errors in API communication
        """
        try:
            # Extract parameters from modelCall
            messages = modelCall.messages
            model = modelCall.model
            options = modelCall.options
            temperature = getattr(options, "temperature", None)
            if temperature is None:
                temperature = model.temperature
            maxTokens = model.maxTokens
            payload = {
                "model": model.name,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": maxTokens
            }
            if modelCall.tools:
                payload["tools"] = modelCall.tools
                payload["tool_choice"] = modelCall.toolChoice or "auto"
            response = await self.httpClient.post(
                model.apiUrl,
                json=payload
            )
            if response.status_code != 200:
                errorMessage = f"Mistral API error: {response.status_code} - {response.text}"
                logger.error(errorMessage)
                # Check for rate limit exceeded (429, e.g. tokens-per-minute)
                if response.status_code == 429:
                    try:
                        errorData = response.json()
                        errorMsg = errorData.get("error", {}).get("message", "Rate limit exceeded")
                        raise RateLimitExceededException(
                            f"Rate limit exceeded for {model.name}: {errorMsg}"
                        )
                    except (ValueError, KeyError, AttributeError):
                        raise RateLimitExceededException(
                            f"Rate limit exceeded for {model.name}"
                        )
                # Check for context length exceeded error
                if response.status_code == 400:
                    try:
                        errorData = response.json()
                        if (errorData.get("error", {}).get("code") == "context_length_exceeded" or
                                "context length" in errorData.get("error", {}).get("message", "").lower() or
                                "too many tokens" in errorData.get("error", {}).get("message", "").lower()):
                            raise ContextLengthExceededException(
                                f"Context length exceeded: {errorData.get('error', {}).get('message', 'Unknown error')}"
                            )
                    except (ValueError, KeyError, AttributeError):
                        pass  # If we can't parse the error, fall through to the generic error
                # Include the actual error details in the exception
                raise HTTPException(status_code=500, detail=errorMessage)
            responseJson = response.json()
            choiceMessage = responseJson["choices"][0]["message"]
            content = choiceMessage.get("content") or ""
            metadata = {"response_id": responseJson.get("id", "")}
            if choiceMessage.get("tool_calls"):
                metadata["toolCalls"] = choiceMessage["tool_calls"]
            return AiModelResponse(
                content=content,
                success=True,
                modelId=model.name,
                metadata=metadata,
            )
        except ContextLengthExceededException:
            # Re-raise context length exceptions without wrapping
            raise
        except RateLimitExceededException:
            # Re-raise rate limit exceptions without wrapping
            raise
        except HTTPException:
            # Re-raise HTTP exceptions without double-wrapping the detail message
            raise
        except Exception as e:
            logger.error(f"Error calling Mistral API: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error calling Mistral API: {str(e)}")
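    # A minimal message sketch for callAiBasic (the message contents are
    # illustrative; the wire format is the OpenAI-compatible schema described
    # in the docstring above):
    #
    #     call = AiModelCall(
    #         model=model,
    #         messages=[
    #             {"role": "system", "content": "You are a helpful assistant."},
    #             {"role": "user", "content": "Summarize this text ..."},
    #         ],
    #     )
    #     response = await connector.callAiBasic(call)
    #     print(response.content)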
    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
        """Stream a Mistral response. Yields str deltas, then the final AiModelResponse."""
        try:
            model = modelCall.model
            options = modelCall.options
            temperature = getattr(options, "temperature", None)
            if temperature is None:
                temperature = model.temperature
            payload: Dict[str, Any] = {
                "model": model.name,
                "messages": modelCall.messages,
                "temperature": temperature,
                "max_tokens": model.maxTokens,
                "stream": True,
            }
            if modelCall.tools:
                payload["tools"] = modelCall.tools
                payload["tool_choice"] = modelCall.toolChoice or "auto"
            fullContent = ""
            toolCallsAccum: Dict[int, Dict[str, Any]] = {}
            async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    bodyStr = body.decode()
                    if response.status_code == 429:
                        try:
                            errorMsg = _json.loads(bodyStr).get("error", {}).get("message", "Rate limit exceeded")
                        except (ValueError, KeyError, AttributeError):
                            errorMsg = f"Rate limit exceeded for {model.name}"
                        raise RateLimitExceededException(f"Rate limit exceeded for {model.name}: {errorMsg}")
                    raise HTTPException(status_code=500, detail=f"Mistral stream error: {response.status_code} - {bodyStr}")
                # Server-sent events: each payload line is prefixed with "data: ".
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[6:]
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = _json.loads(data)
                    except _json.JSONDecodeError:
                        continue
                    # Guard against chunks with an empty choices list.
                    choices = chunk.get("choices") or [{}]
                    delta = choices[0].get("delta", {})
                    if delta.get("content"):
                        fullContent += delta["content"]
                        yield delta["content"]
                    # Tool-call fragments arrive incrementally; accumulate them by index.
                    for tcDelta in delta.get("tool_calls", []):
                        idx = tcDelta.get("index", 0)
                        if idx not in toolCallsAccum:
                            toolCallsAccum[idx] = {
                                "id": tcDelta.get("id", ""),
                                "type": "function",
                                "function": {"name": "", "arguments": ""},
                            }
                        if tcDelta.get("id"):
                            toolCallsAccum[idx]["id"] = tcDelta["id"]
                        fn = tcDelta.get("function", {})
                        if fn.get("name"):
                            toolCallsAccum[idx]["function"]["name"] = fn["name"]
                        if fn.get("arguments"):
                            toolCallsAccum[idx]["function"]["arguments"] += fn["arguments"]
            metadata: Dict[str, Any] = {}
            if toolCallsAccum:
                metadata["toolCalls"] = [toolCallsAccum[i] for i in sorted(toolCallsAccum)]
            yield AiModelResponse(
                content=fullContent,
                success=True,
                modelId=model.name,
                metadata=metadata,
            )
        except (RateLimitExceededException, ContextLengthExceededException, HTTPException):
            raise
        except Exception as e:
            logger.error(f"Error streaming Mistral API: {e}")
            raise HTTPException(status_code=500, detail=f"Error streaming Mistral API: {e}")
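    # A minimal consumption sketch for callAiBasicStream (the `call` object is
    # assumed to be built as in the callAiBasic example above):
    #
    #     async for item in connector.callAiBasicStream(call):
    #         if isinstance(item, AiModelResponse):
    #             finalResponse = item  # last item carries the full content and metadata
    #         else:
    #             print(item, end="", flush=True)  # incremental text delta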
    async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
        """Generate embeddings via the Mistral Embeddings API.

        Reads texts from modelCall.embeddingInput.
        Returns vectors in metadata["embeddings"].
        """
        try:
            model = modelCall.model
            texts = modelCall.embeddingInput or []
            if not texts:
                return AiModelResponse(
                    content="", success=False, error="No embeddingInput provided"
                )
            payload = {"model": model.name, "input": texts}
            response = await self.httpClient.post(model.apiUrl, json=payload)
            if response.status_code != 200:
                errorMessage = f"Mistral Embedding API error: {response.status_code} - {response.text}"
                logger.error(errorMessage)
                if response.status_code == 429:
                    raise RateLimitExceededException(f"Rate limit exceeded for {model.name}")
                if response.status_code == 400:
                    try:
                        errorData = response.json()
                        errMsg = errorData.get("error", {}).get("message", "").lower()
                        errCode = errorData.get("error", {}).get("code", "")
                        if errCode == "context_length_exceeded" or "too many tokens" in errMsg or "maximum context length" in errMsg:
                            raise ContextLengthExceededException(
                                f"Embedding context length exceeded for {model.name}: {errorData.get('error', {}).get('message', '')}"
                            )
                    except (ValueError, KeyError, AttributeError):
                        pass
                raise HTTPException(status_code=500, detail=errorMessage)
            responseJson = response.json()
            embeddings = [item["embedding"] for item in responseJson["data"]]
            usage = responseJson.get("usage", {})
            return AiModelResponse(
                content="",
                success=True,
                modelId=model.name,
                tokensUsed={
                    "input": usage.get("prompt_tokens", 0),
                    "output": 0,
                    "total": usage.get("total_tokens", 0),
                },
                metadata={"embeddings": embeddings},
            )
        except (RateLimitExceededException, ContextLengthExceededException):
            raise
        except Exception as e:
            logger.error(f"Error calling Mistral Embedding API: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error calling Mistral Embedding API: {str(e)}")
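    # A minimal usage sketch for callEmbedding (hedged: the AiModelCall
    # constructor arguments are assumptions based on the fields read above):
    #
    #     embedModel = next(m for m in connector.getModels() if m.name == "mistral-embed")
    #     call = AiModelCall(model=embedModel, embeddingInput=["first text", "second text"])
    #     response = await connector.callEmbedding(call)
    #     vectors = response.metadata["embeddings"]  # one vector per input text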
    async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
        """
        Analyzes an image with the Mistral Vision API using the standardized pattern.

        Mistral Large 3 is multimodal and accepts image inputs in the OpenAI-compatible
        format: {"type": "image_url", "image_url": {"url": "data:...base64,..."}}

        Args:
            modelCall: AiModelCall whose messages already embed the image data

        Returns:
            AiModelResponse with analysis content
        """
        try:
            # Extract parameters from modelCall
            messages = modelCall.messages
            model = modelCall.model
            # Messages must already embed the image data in the OpenAI-compatible
            # vision format, which Mistral Large 3 supports natively; verify that
            # at least one message with content was provided.
            if not messages or not messages[0].get("content"):
                raise ValueError("No messages provided for image analysis")
            logger.debug(f"Starting image analysis with {len(messages)} message(s)...")
            # Use parameters from the model
            temperature = model.temperature
            payload = {
                "model": model.name,
                "messages": messages,
                "temperature": temperature
            }
            response = await self.httpClient.post(
                model.apiUrl,
                json=payload
            )
            if response.status_code != 200:
                logger.error(f"Mistral API error: {response.status_code} - {response.text}")
                raise HTTPException(status_code=500, detail="Error communicating with Mistral API")
            responseJson = response.json()
            content = responseJson["choices"][0]["message"]["content"]
            return AiModelResponse(
                content=content,
                success=True,
                modelId=model.name,
                metadata={"response_id": responseJson.get("id", "")}
            )
        except Exception as e:
            logger.error(f"Error during image analysis: {str(e)}", exc_info=True)
            return AiModelResponse(
                content="",
                success=False,
                error=f"Error during image analysis: {str(e)}"
            )
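    # A minimal message sketch for callAiImage (the base64 payload is a
    # placeholder; only the structure matters, matching the format named in
    # the docstring above):
    #
    #     messages = [{
    #         "role": "user",
    #         "content": [
    #             {"type": "text", "text": "Describe this image."},
    #             {"type": "image_url", "image_url": {"url": "data:image/png;base64,<...>"}},
    #         ],
    #     }]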