# gateway/modules/interfaces/interfaceAiObjects.py
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import logging
import asyncio
import uuid
import base64
from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
from dataclasses import dataclass, field
import time
logger = logging.getLogger(__name__)
from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.aicore.aicoreBase import RateLimitExceededException
from modules.datamodels.datamodelAi import (
    AiModel,
    AiCallOptions,
    AiCallRequest,
    AiCallResponse,
    OperationTypeEnum,
    AiModelCall,
    AiModelResponse,
)
from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy


# Dynamic model registry - models are now loaded from connectors via the aicore system
@dataclass(slots=True)
class AiObjects:
    """Centralized AI interface: dynamically discovers and uses AI models.

    billingCallback: Set by serviceAi before AI calls. Called after EVERY individual
    model call with the AiCallResponse. This ensures per-model-call billing with the
    exact provider + model name. The callback handles the billing recording.
    """

    billingCallback: Optional[Callable] = field(default=None, repr=False)
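
    # Illustrative wiring sketch (hedged: the real call site lives in serviceAi,
    # which is not shown here, and `billingService.record` is a hypothetical name):
    #
    #     aiObjects = await AiObjects.create()
    #     aiObjects.billingCallback = lambda resp: billingService.record(
    #         provider=resp.provider,
    #         model=resp.modelName,
    #         priceCHF=resp.priceCHF,
    #     )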

    def __post_init__(self) -> None:
        # Auto-discover and register all available connectors
        self._discoverAndRegisterConnectors()

    def _discoverAndRegisterConnectors(self) -> None:
        """Auto-discover and register all available AI connectors."""
        logger.info("Auto-discovering AI connectors...")
        # Use the model registry's built-in discovery mechanism
        discoveredConnectors = modelRegistry.discoverConnectors()
        # Register each discovered connector
        for connector in discoveredConnectors:
            modelRegistry.registerConnector(connector)
            logger.info(f"Registered connector: {connector.getConnectorType()}")
        logger.info(f"Total connectors registered: {len(discoveredConnectors)}")
        logger.info("All AI connectors registered with dynamic model registry")

    @classmethod
    async def create(cls) -> "AiObjects":
        """Create an AiObjects instance with auto-discovered connectors."""
        # No need to create connectors manually - they are auto-discovered
        return cls()

    def _selectModel(self, prompt: str, context: str, options: AiCallOptions) -> str:
        """Select the best model via the dynamic model selection system.

        Returns the displayName (unique identifier) of the selected model.
        """
        # Get available models from the dynamic registry
        availableModels = modelRegistry.getAvailableModels()
        if not availableModels:
            logger.error("No models available in the registry")
            raise ValueError("No AI models available")
        # Use the dynamic model selector
        selectedModel = modelSelector.selectModel(prompt, context, options, availableModels)
        if not selectedModel:
            logger.error("No suitable model found for the given criteria")
            raise ValueError("No suitable AI model found")
        logger.info(f"Selected model: {selectedModel.name} ({selectedModel.displayName})")
        return selectedModel.displayName

    # AI for Extraction, Processing, Generation
    async def callWithTextContext(self, request: AiCallRequest) -> AiCallResponse:
        """Call AI model for traditional text/context calls with fallback mechanism.

        Supports two modes:
        - Legacy: prompt + context → constructs messages internally
        - Agent: request.messages provided → passes through directly
        """
        prompt = request.prompt
        context = request.context or ""
        options = request.options
        # Get failover models for this operation type
        availableModels = modelRegistry.getAvailableModels()
        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
            if filteredModels:
                availableModels = filteredModels
            else:
                errorMsg = f"No models match allowedProviders {allowedProviders} for operation {options.operationType}"
                logger.error(errorMsg)
                return AiCallResponse(
                    content=errorMsg, modelName="error", priceCHF=0.0,
                    processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1,
                )
        failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
        if not failoverModelList:
            errorMsg = f"No suitable models found for operation {options.operationType}"
            logger.error(errorMsg)
            return AiCallResponse(
                content=errorMsg,
                modelName="error",
                priceCHF=0.0,
                processingTime=0.0,
                bytesSent=0,
                bytesReceived=0,
                errorCount=1,
            )
        _MAX_SHORT_RETRY = 15.0
        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Attempting AI call with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
                if request.messages:
                    response = await self._callWithMessages(model, request.messages, options, request.tools, toolChoice=request.toolChoice)
                else:
                    response = await self._callWithModel(model, prompt, context, options)
                logger.info(f"AI call successful with model: {model.name}")
                return response
            except RateLimitExceededException as rle:
                retryAfter = rle.retryAfterSeconds
                lastError = rle
                if 0 < retryAfter <= _MAX_SHORT_RETRY:
                    logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry")
                    await asyncio.sleep(retryAfter + 0.5)
                    try:
                        if request.messages:
                            response = await self._callWithMessages(model, request.messages, options, request.tools, toolChoice=request.toolChoice)
                        else:
                            response = await self._callWithModel(model, prompt, context, options)
                        logger.info(f"AI call successful with {model.name} after rate-limit retry")
                        return response
                    except Exception as retryErr:
                        lastError = retryErr
                        logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}")
                else:
                    logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over")
                    cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
                    modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
                    if attempt < len(failoverModelList) - 1:
                        continue
                    logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
                    break
            except Exception as e:
                lastError = e
                logger.warning(f"AI call failed with model {model.name}: {str(e)}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    logger.info("Trying next failover model...")
                    continue
                else:
                    logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
                    break
        # All failover attempts failed - return an error response
        errorMsg = f"All AI models failed for operation {options.operationType}. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return AiCallResponse(
            content=errorMsg,
            modelName="error",
            priceCHF=0.0,
            processingTime=0.0,
            bytesSent=0,
            bytesReceived=0,
            errorCount=1,
        )

    def _createErrorResponse(self, errorMsg: str, inputBytes: int, outputBytes: int) -> AiCallResponse:
        """Create an error response."""
        return AiCallResponse(
            content=errorMsg,
            modelName="error",
            priceCHF=0.0,
            processingTime=0.0,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=1,
        )

    async def _callWithModel(self, model: AiModel, prompt: str, context: str, options: Optional[AiCallOptions] = None) -> AiCallResponse:
        """Call a specific model and return the response."""
        # Calculate input bytes from prompt and context
        inputBytes = len((prompt + context).encode('utf-8'))
        # Replace the <TOKEN_LIMIT> placeholder with the model's maxTokens value
        if "<TOKEN_LIMIT>" in prompt:
            if model.maxTokens > 0:
                tokenLimit = str(model.maxTokens)
                modelPrompt = prompt.replace("<TOKEN_LIMIT>", tokenLimit)
                logger.debug(f"Replaced <TOKEN_LIMIT> with {tokenLimit} for model {model.name}")
            else:
                raise ValueError(f"Model {model.name} has invalid maxTokens ({model.maxTokens}). Cannot set token limit.")
        else:
            modelPrompt = prompt
        # Build the messages array with the replaced content
        messages = []
        if context:
            messages.append({"role": "system", "content": f"Context from documents:\n{context}"})
        messages.append({"role": "user", "content": modelPrompt})
        # Start timing
        startTime = time.time()
        # Call the model's function directly - completely generic
        if model.functionCall:
            # Create a standardized call object
            modelCall = AiModelCall(
                messages=messages,
                model=model,
                options=options or {}
            )
            # Log before calling the model
            contextSize = len(context.encode('utf-8')) if context else 0
            promptSize = len(modelPrompt.encode('utf-8')) if modelPrompt else 0
            totalInputSize = contextSize + promptSize
            logger.debug(f"Calling model {model.name} with {len(messages)} messages, context size: {contextSize} bytes, prompt size: {promptSize} bytes, total input: {totalInputSize} bytes")
            # Call the model through the standardized interface
            modelResponse = await model.functionCall(modelCall)
            # Log after a successful call
            logger.debug(f"Model {model.name} returned successfully")
            # Extract content from the standardized response
            if not modelResponse.success:
                raise ValueError(f"Model call failed: {modelResponse.error}")
            content = modelResponse.content
        else:
            raise ValueError(f"Model {model.name} has no function call defined")
        # Calculate timing and output bytes
        endTime = time.time()
        processingTime = endTime - startTime
        outputBytes = len(content.encode("utf-8"))
        # Calculate the price using the model's own price calculation method
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
        )
        # BILLING: Record billing for THIS specific model call.
        # billingCallback is set by serviceAi and records one billing transaction
        # per model call with the exact provider + model name.
        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")
        return response

    async def _callWithMessages(self, model: AiModel, messages: List[Dict[str, Any]],
                                options: Optional[AiCallOptions] = None,
                                tools: Optional[List[Dict[str, Any]]] = None,
                                toolChoice: Any = None) -> AiCallResponse:
        """Call a model with pre-built messages (agent mode). Supports tools for native function calling."""
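        # Illustrative `tools` shape (assumption: connectors accept OpenAI-style
        # function-tool schemas; this method only passes `tools` through verbatim):
        #
        #     tools = [{"type": "function", "function": {
        #         "name": "getWeather",  # hypothetical tool
        #         "parameters": {"type": "object",
        #                        "properties": {"city": {"type": "string"}}},
        #     }}]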
        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()
        if not model.functionCall:
            raise ValueError(f"Model {model.name} has no function call defined")
        modelCall = AiModelCall(
            messages=messages,
            model=model,
            options=options or {},
            tools=tools,
            toolChoice=toolChoice,
        )
        modelResponse = await model.functionCall(modelCall)
        if not modelResponse.success:
            raise ValueError(f"Model call failed: {modelResponse.error}")
        endTime = time.time()
        processingTime = endTime - startTime
        content = modelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
        # Extract tool calls from metadata if present (native function calling)
        responseToolCalls = None
        if modelResponse.metadata:
            responseToolCalls = modelResponse.metadata.get("toolCalls")
        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls,
        )
        response._modelMaxTokens = model.maxTokens
        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")
        return response

    async def callWithTextContextStream(
        self, request: AiCallRequest
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Streaming variant of callWithTextContext. Yields str deltas, then the final AiCallResponse."""
        options = request.options
        availableModels = modelRegistry.getAvailableModels()
        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filtered = [m for m in availableModels if m.connectorType in allowedProviders]
            if filtered:
                availableModels = filtered
            else:
                yield AiCallResponse(
                    content=f"No models match allowedProviders {allowedProviders} for operation {options.operationType}",
                    modelName="error", priceCHF=0.0, processingTime=0.0,
                    bytesSent=0, bytesReceived=0, errorCount=1,
                )
                return
        failoverModelList = modelSelector.getFailoverModelList(
            request.prompt, request.context or "", options, availableModels
        )
        if not failoverModelList:
            yield AiCallResponse(
                content=f"No suitable models found for operation {options.operationType}",
                modelName="error", priceCHF=0.0, processingTime=0.0,
                bytesSent=0, bytesReceived=0, errorCount=1,
            )
            return
        _MAX_SHORT_RETRY = 15.0
        lastError = None
        # Note: streaming always goes through the pre-built messages path, so
        # request.messages is expected to be set (agent mode).
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Streaming AI call with model: {model.name} (attempt {attempt + 1})")
                async for chunk in self._callWithMessagesStream(model, request.messages, options, request.tools, toolChoice=request.toolChoice):
                    yield chunk
                return
            except RateLimitExceededException as rle:
                retryAfter = rle.retryAfterSeconds
                lastError = rle
                if 0 < retryAfter <= _MAX_SHORT_RETRY:
                    logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry")
                    await asyncio.sleep(retryAfter + 0.5)
                    try:
                        async for chunk in self._callWithMessagesStream(model, request.messages, options, request.tools, toolChoice=request.toolChoice):
                            yield chunk
                        return
                    except Exception as retryErr:
                        lastError = retryErr
                        logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}")
                else:
                    logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over")
                    cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
                    modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
                    if attempt < len(failoverModelList) - 1:
                        continue
                    break
            except Exception as e:
                lastError = e
                logger.warning(f"Streaming AI call failed with {model.name}: {e}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break
        yield AiCallResponse(
            content=f"All models failed (stream). Last error: {lastError}",
            modelName="error", priceCHF=0.0, processingTime=0.0,
            bytesSent=0, bytesReceived=0, errorCount=1,
        )

    async def _callWithMessagesStream(
        self, model: AiModel, messages: List[Dict[str, Any]],
        options: Optional[AiCallOptions] = None, tools: Optional[List[Dict[str, Any]]] = None,
        toolChoice: Any = None,
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Stream a model call. Yields str deltas, then the final AiCallResponse with billing."""
        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()
        if not model.functionCallStream:
            # Fall back to the non-streaming path and emit the full content as a single delta
            response = await self._callWithMessages(model, messages, options, tools, toolChoice=toolChoice)
            if response.content:
                yield response.content
            yield response
            return
        modelCall = AiModelCall(
            messages=messages, model=model,
            options=options or {}, tools=tools,
            toolChoice=toolChoice,
        )
        finalModelResponse = None
        async for item in model.functionCallStream(modelCall):
            if isinstance(item, AiModelResponse):
                finalModelResponse = item
            else:
                yield item
        if not finalModelResponse:
            raise ValueError(f"Stream from {model.name} produced no final AiModelResponse")
        endTime = time.time()
        processingTime = endTime - startTime
        content = finalModelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
        responseToolCalls = None
        if finalModelResponse.metadata:
            responseToolCalls = finalModelResponse.metadata.get("toolCalls")
        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls,
        )
        response._modelMaxTokens = model.maxTokens
        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record stream billing for {model.name}: {e}")
        yield response

    async def callEmbedding(self, texts: List[str], options: Optional[AiCallOptions] = None) -> AiCallResponse:
        """Generate embeddings for a list of texts using the best available embedding model.

        Token-aware batching: splits the texts list into batches that respect the
        model's contextLength (with a 10% safety margin). Each batch is sent as a
        separate API call; the resulting embeddings are merged in order.

        Failover across providers (OpenAI -> Mistral) works identically to chat models,
        but ContextLengthExceededException is NOT retried via failover (same limits).

        Returns:
            AiCallResponse with metadata["embeddings"] containing the vectors.
        """
        from modules.aicore.aicoreBase import ContextLengthExceededException as _CtxExc
        if options is None:
            options = AiCallOptions(operationType=OperationTypeEnum.EMBEDDING)
        else:
            options.operationType = OperationTypeEnum.EMBEDDING
        # Representative text sample, used only for model selection
        combinedText = " ".join(texts[:3])[:500]
        availableModels = modelRegistry.getAvailableModels()
        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filtered = [m for m in availableModels if m.connectorType in allowedProviders]
            if filtered:
                availableModels = filtered
            else:
                logger.warning(f"No embedding models match allowedProviders {allowedProviders}")
        failoverModelList = modelSelector.getFailoverModelList(
            combinedText, "", options, availableModels
        )
        if not failoverModelList:
            return AiCallResponse(
                content="", modelName="error", priceCHF=0.0,
                processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1,
            )
        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
                inputBytes = sum(len(t.encode("utf-8")) for t in texts)
                startTime = time.time()
                batches = _buildEmbeddingBatches(texts, model.contextLength)
                logger.info(
                    f"Embedding: {len(texts)} texts -> {len(batches)} batch(es), "
                    f"model contextLength={model.contextLength}"
                )
                allEmbeddings: List[List[float]] = []
                totalPriceCHF = 0.0
                for batchIdx, batch in enumerate(batches):
                    modelCall = AiModelCall(
                        model=model, options=options, embeddingInput=batch
                    )
                    modelResponse = await model.functionCall(modelCall)
                    if not modelResponse.success:
                        raise ValueError(f"Embedding batch {batchIdx + 1} failed: {modelResponse.error}")
                    batchEmbeddings = (modelResponse.metadata or {}).get("embeddings", [])
                    allEmbeddings.extend(batchEmbeddings)
                    batchBytes = sum(len(t.encode("utf-8")) for t in batch)
                    totalPriceCHF += model.calculatepriceCHF(0, batchBytes, 0)
                processingTime = time.time() - startTime
                if totalPriceCHF == 0.0:
                    totalPriceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)
                response = AiCallResponse(
                    content="", modelName=model.name, provider=model.connectorType,
                    priceCHF=totalPriceCHF, processingTime=processingTime,
                    bytesSent=inputBytes, bytesReceived=0, errorCount=0,
                    metadata={"embeddings": allEmbeddings},
                )
                if self.billingCallback:
                    try:
                        self.billingCallback(response)
                    except Exception as e:
                        logger.error(f"BILLING: Failed to record billing for embedding {model.name}: {e}")
                return response
            except _CtxExc as e:
                logger.error(f"ContextLengthExceeded for {model.name} despite batching; aborting failover: {e}")
                return AiCallResponse(
                    content=str(e), modelName=model.name, priceCHF=0.0,
                    processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1,
                )
            except RateLimitExceededException as rle:
                retryAfter = rle.retryAfterSeconds
                lastError = rle
                cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
                logger.warning(f"Rate limit on {model.name} during embedding (retryAfter={retryAfter:.1f}s)")
                modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
                if attempt < len(failoverModelList) - 1:
                    continue
                break
            except Exception as e:
                lastError = e
                logger.warning(f"Embedding call failed with {model.name}: {str(e)}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break
        errorMsg = f"All embedding models failed. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return AiCallResponse(
            content=errorMsg, modelName="error", priceCHF=0.0,
            processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1,
        )

    # Utility methods
    async def listAvailableModels(self, connectorType: Optional[str] = None) -> List[Dict[str, Any]]:
        """List available models, optionally filtered by connector type."""
        models = modelRegistry.getAvailableModels()
        if connectorType:
            return [model.model_dump() for model in models if model.connectorType == connectorType]
        return [model.model_dump() for model in models]

    async def getModelInfo(self, displayName: str) -> Dict[str, Any]:
        """Get information about a specific model by displayName."""
        model = modelRegistry.getModel(displayName)
        if not model:
            raise ValueError(f"Model with displayName '{displayName}' not found")
        return model.model_dump()

    async def getModelsByTag(self, tag: str) -> List[str]:
        """Get the displayNames (unique identifiers) of models that carry a specific tag."""
        models = modelRegistry.getModelsByTag(tag)
        return [model.displayName for model in models]


# =============================================================================
# Internal helpers
# =============================================================================
_CHARS_PER_TOKEN = 4
_SAFETY_MARGIN = 0.90


def _estimateTokens(text: str) -> int:
    """Rough token estimate: 1 token ~ 4 characters."""
    return max(1, len(text) // _CHARS_PER_TOKEN)
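

# Heuristic check (illustrative): _estimateTokens("abcdefgh") == 2 and
# _estimateTokens("ab") == 1; the estimate is clamped to at least one token.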


def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[str]]:
    """Split a list of texts into batches whose total estimated token count
    stays within the model's contextLength (with a safety margin).

    Each individual text is assumed to already be within limits (enforced by
    the chunking layer). If a single text exceeds the budget, it is placed
    in its own batch as a last resort.
    """
    if not texts:
        return []
    if contextLength <= 0:
        return [texts]
    maxTokensPerBatch = int(contextLength * _SAFETY_MARGIN)
    batches: List[List[str]] = []
    currentBatch: List[str] = []
    currentTokens = 0
    for text in texts:
        textTokens = _estimateTokens(text)
        if currentBatch and (currentTokens + textTokens) > maxTokensPerBatch:
            batches.append(currentBatch)
            currentBatch = []
            currentTokens = 0
        currentBatch.append(text)
        currentTokens += textTokens
    if currentBatch:
        batches.append(currentBatch)
    return batches
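

# Worked example (illustrative numbers): with contextLength=100 the batch budget
# is int(100 * 0.90) = 90 estimated tokens. Three texts of 160 characters each
# estimate to 40 tokens apiece, so the first two share a batch (40 + 40 = 80)
# and the third starts a new one (80 + 40 > 90): [[t1, t2], [t3]].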